emacs-elpa-diffs
[Top][All Lists]
Advanced

[Date Prev][Date Next][Thread Prev][Thread Next][Date Index][Thread Index]

[elpa] externals/llm 8ca514e53e: Add batch embeddings capability, implement for OpenAI and Ollama (#93)


From: ELPA Syncer
Subject: [elpa] externals/llm 8ca514e53e: Add batch embeddings capability, implement for OpenAI and Ollama (#93)
Date: Mon, 4 Nov 2024 00:58:29 -0500 (EST)

branch: externals/llm
commit 8ca514e53e7e05ae782a47f4cc32c81cd3349a65
Author: Andrew Hyatt <ahyatt@gmail.com>
Commit: GitHub <noreply@github.com>

    Add batch embeddings capability, implement for OpenAI and Ollama (#93)
    
    This will implement https://github.com/ahyatt/llm/issues/92.
---
 NEWS.org                |  1 +
 llm-azure.el            |  2 +-
 llm-gemini.el           |  2 +-
 llm-integration-test.el | 67 +++++++++++++++++++++++++++++++++++++++++++------
 llm-llamacpp.el         |  2 +-
 llm-ollama.el           | 10 ++++++--
 llm-openai.el           | 23 +++++++++++++----
 llm-provider-utils.el   | 50 +++++++++++++++++++++++++++++++++---
 llm-vertex.el           |  2 +-
 llm.el                  | 40 +++++++++++++++++++++++++++++
 10 files changed, 176 insertions(+), 23 deletions(-)

diff --git a/NEWS.org b/NEWS.org
index bb93138627..8c291e6aea 100644
--- a/NEWS.org
+++ b/NEWS.org
@@ -1,4 +1,5 @@
 * Version 0.18.0
+- Add batch embeddings capability (currently for just Open AI and Ollama).
 - Add Microsoft Azure's Open AI
 - Remove testing and other development files from ELPA packaging.
 - Remove vendored =plz-event-source= and =plz-media-type=, and add 
requirements.
diff --git a/llm-azure.el b/llm-azure.el
index cf51224bc1..5fa0440fd5 100644
--- a/llm-azure.el
+++ b/llm-azure.el
@@ -40,7 +40,7 @@
           (llm-azure-url provider)
           (llm-azure-chat-model provider)))
 
-(cl-defmethod llm-provider-embedding-url ((provider llm-azure))
+(cl-defmethod llm-provider-embedding-url ((provider llm-azure) &optional _)
   (format "%s/openai/deployments/%s/embeddings?api-version=2024-08-01-preview"
           (llm-azure-url provider)
           (llm-azure-embedding-model provider)))
diff --git a/llm-gemini.el b/llm-gemini.el
index 3c0b5f2a8f..5c70c406b4 100644
--- a/llm-gemini.el
+++ b/llm-gemini.el
@@ -43,7 +43,7 @@ You can get this at https://makersuite.google.com/app/apikey."
   "Return nonfree terms of service for Gemini."
   "https://policies.google.com/terms/generative-ai")
 
-(cl-defmethod llm-provider-embedding-url ((provider llm-gemini))
+(cl-defmethod llm-provider-embedding-url ((provider llm-gemini) &optional _)
   "Return the URL for the EMBEDDING request for STRING from PROVIDER."
   (format "https://generativelanguage.googleapis.com/v1beta/models/%s:embedContent?key=%s"
           (llm-gemini-embedding-model provider)
diff --git a/llm-integration-test.el b/llm-integration-test.el
index e8fb296823..bc0563cf0c 100644
--- a/llm-integration-test.el
+++ b/llm-integration-test.el
@@ -139,11 +139,63 @@
          (ert-info ((format "Using provider %s" (llm-name provider)))
            ,@body)))))
 
+(llm-def-integration-test llm-embedding (provider)
+  ;; Only providers advertising the `embeddings' capability are exercised.
+  (when (memq 'embeddings (llm-capabilities provider))
+    (let ((embedding (llm-embedding provider "Paris")))
+      (should (and (vectorp embedding) (> (length embedding) 0))))))
+
+(llm-def-integration-test llm-embedding-async (provider)
+  (when (member 'embeddings (llm-capabilities provider))
+    (let ((result nil)
+          (err-result nil)             ; set by the error callback
+          (buf (current-buffer))
+          (llm-warn-on-nonfree nil))
+      (llm-embedding-async
+       provider
+       "Paris"
+       (lambda (response)
+         (should (eq (current-buffer) buf)) ; callback must run in the calling buffer
+         (setq result response))
+       (lambda (_ err) (setq err-result err))) ; args: signal symbol, then message
+      (while (not (or result err-result)) (sleep-for 0.1)) ; also stop on error
+      (if err-result (error "Error: %s" err-result))
+      (should (vectorp result))
+      (should (> (length result) 0)))))
+
+(llm-def-integration-test llm-batch-embeddings (provider)
+  (when (member 'embeddings-batch (llm-capabilities provider))
+    (let ((result (llm-batch-embeddings provider '("Paris" "France"))))
+      (should (listp result))
+      (should (= (length result) 2))
+      (should (vectorp (nth 0 result))) ; `nth', not `aref': RESULT is a list
+      (should (vectorp (nth 1 result))))))
+
+(llm-def-integration-test llm-batch-embedding-async (provider)
+  (when (member 'embeddings-batch (llm-capabilities provider))
+    (let ((result nil)
+          (err-result nil)             ; set by the error callback
+          (buf (current-buffer))
+          (llm-warn-on-nonfree nil))
+      (llm-batch-embeddings-async
+       provider
+       '("Paris" "France")
+       (lambda (response)
+         (should (eq (current-buffer) buf)) ; callback must run in the calling buffer
+         (setq result response))
+       (lambda (_ err) (setq err-result err))) ; args: signal symbol, then message
+      (while (not (or result err-result)) (sleep-for 0.1)) ; also stop on error
+      (if err-result (error "Error: %s" err-result))
+      (should (listp result))
+      (should (= (length result) 2))
+      (should (vectorp (nth 0 result))) ; `nth', not `aref': RESULT is a list
+      (should (vectorp (nth 1 result))))))
+
 (llm-def-integration-test llm-chat (provider)
   (should (equal
-           (llm-chat
-            provider
-            (llm-make-chat-prompt llm-integration-test-chat-prompt))
+           (string-trim (llm-chat
+                         provider
+                         (llm-make-chat-prompt 
llm-integration-test-chat-prompt)))
            llm-integration-test-chat-answer)))
 
 (llm-def-integration-test llm-chat-async (provider)
@@ -161,9 +213,8 @@
        (setq err-result err)))
     (while (not (or result err-result))
       (sleep-for 0.1))
-    (if err-result
-        (error err-result))
-    (should (equal result llm-integration-test-chat-answer))))
+    (if err-result (error err-result))
+    (should (equal (string-trim result) llm-integration-test-chat-answer))))
 
 (llm-def-integration-test llm-chat-streaming (provider)
   (when (member 'streaming (llm-capabilities provider))
@@ -189,8 +240,8 @@
                   (time-less-p (time-subtract (current-time) start-time) 10))
         (sleep-for 0.1))
       (if err-result (error err-result))
-      (should (equal returned-result llm-integration-test-chat-answer))
-      (should (equal streamed-result llm-integration-test-chat-answer)))))
+      (should (equal (string-trim returned-result) 
llm-integration-test-chat-answer))
+      (should (equal (string-trim streamed-result) 
llm-integration-test-chat-answer)))))
 
 (llm-def-integration-test llm-function-call (provider)
   (when (member 'function-calls (llm-capabilities provider))
diff --git a/llm-llamacpp.el b/llm-llamacpp.el
index 3ffe8b70d3..69ad298ed1 100644
--- a/llm-llamacpp.el
+++ b/llm-llamacpp.el
@@ -60,7 +60,7 @@ PATH is the path to append to the URL, not prefixed with a 
slash."
         (port (llm-llamacpp-port provider)))
     (format "%s://%s:%d/%s" scheme host port path)))
 
-(cl-defmethod llm-provider-embedding-url ((provider llm-llamacpp))
+(cl-defmethod llm-provider-embedding-url ((provider llm-llamacpp) &optional _)
   (llm-llamacpp--url provider "embedding"))
 
 (cl-defmethod llm-provider-chat-url ((provider llm-llamacpp))
diff --git a/llm-ollama.el b/llm-ollama.el
index 8faf58bcf3..b88666a315 100644
--- a/llm-ollama.el
+++ b/llm-ollama.el
@@ -73,7 +73,7 @@ EMBEDDING-MODEL is the model to use for embeddings.  It is 
required."
   (format "%s://%s:%d/api/%s" (llm-ollama-scheme provider )(llm-ollama-host 
provider)
           (llm-ollama-port provider) method))
 
-(cl-defmethod llm-provider-embedding-url ((provider llm-ollama))
+(cl-defmethod llm-provider-embedding-url ((provider llm-ollama) &optional _)
   (llm-ollama--url provider "embed"))
 
 (cl-defmethod llm-provider-chat-url ((provider llm-ollama))
@@ -94,10 +94,16 @@ PROVIDER is the llm-ollama provider."
   `(("input" . ,string)
     ("model" . ,(llm-ollama-embedding-model provider))))
 
+(cl-defmethod llm-provider-batch-embeddings-request ((provider llm-ollama) 
strings)
+  (llm-provider-embedding-request provider strings)) ; reuse the single-string request builder, passing the list through
+
 (cl-defmethod llm-provider-embedding-extract-result ((_ llm-ollama) response)
   "Return the embedding from the server RESPONSE."
   (aref (assoc-default 'embeddings response) 0))
 
+(cl-defmethod llm-provider-batch-embeddings-extract-result ((_ llm-ollama) 
response)
+  (append (assoc-default 'embeddings response) nil)) ; `append' with nil turns the `embeddings' sequence into a list
+
 (cl-defmethod llm-provider-chat-extract-result ((_ llm-ollama) response)
   "Return the chat response from the server RESPONSE."
   (assoc-default 'content (assoc-default 'message response)))
@@ -174,7 +180,7 @@ PROVIDER is the llm-ollama provider."
                                              (llm-ollama-embedding-model 
provider))))
                        (and embedding-model
                             (member 'embedding (llm-model-capabilities 
embedding-model)))))
-            '(embeddings))
+            '(embeddings embeddings-batch))
           (when (let ((chat-model (llm-models-match
                                    (llm-ollama-chat-model provider))))
                   (and chat-model
diff --git a/llm-openai.el b/llm-openai.el
index 88b0710a04..11d6ba6f7e 100644
--- a/llm-openai.el
+++ b/llm-openai.el
@@ -66,17 +66,30 @@ https://api.example.com/v1/chat, then URL should be
   "Return Open AI's nonfree terms of service."
   "https://openai.com/policies/terms-of-use")
 
-(cl-defmethod llm-provider-embedding-request ((provider llm-openai) string)
-  "Return the request to the server for the embedding of STRING.
+(cl-defmethod llm-provider-embedding-request ((provider llm-openai) 
string-or-list)
+  "Return the request to the server for the embedding of STRING-OR-LIST.
 PROVIDER is the Open AI provider struct."
-  `(("input" . ,string)
+  `(("input" . ,string-or-list)
     ("model" . ,(or (llm-openai-embedding-model provider)
                     "text-embedding-3-small"))))
 
+(cl-defmethod llm-provider-batch-embeddings-request ((provider llm-openai) 
batch)
+  (llm-provider-embedding-request provider batch)) ; same builder as single embeddings; BATCH becomes the "input" value
+
 (cl-defmethod llm-provider-embedding-extract-result ((_ llm-openai) response)
   "Return the embedding from the server RESPONSE."
   (assoc-default 'embedding (aref (assoc-default 'data response) 0)))
 
+(cl-defmethod llm-provider-batch-embeddings-extract-result ((_ llm-openai) 
response)
+  "Return the list of embeddings from the server RESPONSE, ordered by `index'."
+  (let* ((data (assoc-default 'data response))
+         (vec (make-vector (length data) nil))) ; one slot per returned item
+    (mapc (lambda (d)
+            (aset vec (assoc-default 'index d) ; place each item at its own index
+                  (assoc-default 'embedding d)))
+          data)
+    (append vec nil))) ; convert the vector to a list
+
 (cl-defgeneric llm-openai--check-key (provider)
   "Check that the key is set for the Open AI PROVIDER.")
 
@@ -114,7 +127,7 @@ PROVIDER is the Open AI provider struct."
 (cl-defmethod llm-openai--url ((_ llm-openai) command)
   (concat "https://api.openai.com/v1/" command))
 
-(cl-defmethod llm-provider-embedding-url ((provider llm-openai))
+(cl-defmethod llm-provider-embedding-url ((provider llm-openai) &optional _)
   (llm-openai--url provider "embeddings"))
 
 (cl-defmethod llm-provider-chat-url ((provider llm-openai))
@@ -266,7 +279,7 @@ RESPONSE can be nil if the response is complete."
 (cl-defmethod llm-capabilities ((provider llm-openai-compatible))
   (append '(streaming)
           (when (llm-openai-embedding-model provider)
-            '(embeddings))
+            '(embeddings embeddings-batch))
           (let ((model (llm-models-match (llm-openai-chat-model provider))))
             (when (and model (member 'tool-use (llm-model-capabilities model)))
               '(function-calls)))))
diff --git a/llm-provider-utils.el b/llm-provider-utils.el
index 4013c671fd..9d3598df77 100644
--- a/llm-provider-utils.el
+++ b/llm-provider-utils.el
@@ -83,12 +83,16 @@ PROVIDER is the provider that will be used to make the 
request.")
   nil)
 
 ;; Methods for embeddings
-(cl-defgeneric llm-provider-embedding-url (provider)
-  "Return the URL for embeddings for the PROVIDER.")
+(cl-defgeneric llm-provider-embedding-url (provider &optional batch)
+  "Return the URL for embeddings for the PROVIDER.
+BATCH is true if this is a batch request.")
 
 (cl-defgeneric llm-provider-embedding-request (provider string)
   "Return the request for the PROVIDER for STRING.")
 
+(cl-defgeneric llm-provider-batch-embeddings-request (provider string-list)
+  "Return the batch embeddings request for the PROVIDER for STRING-LIST.")
+
 (cl-defgeneric llm-provider-embedding-extract-error (provider response)
   "Return an error message from RESPONSE for the PROVIDER.
 
@@ -103,6 +107,9 @@ Return nil if there is no error.")
 (cl-defgeneric llm-provider-embedding-extract-result (provider response)
   "Return the result from RESPONSE for the PROVIDER.")
 
+(cl-defgeneric llm-provider-batch-embeddings-extract-result (provider response)
+  "Return the embeddings from RESPONSE for the PROVIDER for a batch request.")
+
 ;; Methods for chat
 
 (cl-defgeneric llm-provider-chat-url (provider)
@@ -219,7 +226,7 @@ return a list of `llm-chat-function-call' structs.")
 (cl-defmethod llm-embedding ((provider llm-standard-full-provider) string)
   (llm-provider-request-prelude provider)
   (let ((response (llm-request-plz-sync
-                   (llm-provider-embedding-url provider)
+                   (llm-provider-embedding-url provider nil)
                    :timeout (llm-provider-chat-timeout provider)
                    :headers (llm-provider-headers provider)
                    :data (llm-provider-embedding-request provider string))))
@@ -231,7 +238,7 @@ return a list of `llm-chat-function-call' structs.")
   (llm-provider-request-prelude provider)
   (let ((buf (current-buffer)))
     (llm-request-plz-async
-     (llm-provider-embedding-url provider)
+     (llm-provider-embedding-url provider nil)
      :headers (llm-provider-headers provider)
      :data (llm-provider-embedding-request provider string)
      :on-success (lambda (data)
@@ -251,6 +258,41 @@ return a list of `llm-chat-function-call' structs.")
                          provider data)
                         "Unknown error")))))))
 
+(cl-defmethod llm-batch-embeddings ((provider llm-standard-full-provider) string-list)
+  (llm-provider-request-prelude provider)
+  (let ((response (llm-request-plz-sync
+                   (llm-provider-embedding-url provider t) ; t = batch request
+                   :timeout (llm-provider-chat-timeout provider)
+                   :headers (llm-provider-headers provider)
+                   :data (llm-provider-batch-embeddings-request provider string-list))))
+    (if-let ((err-msg (llm-provider-embedding-extract-error provider response)))
+        (error "%s" err-msg) ; explicit format: ERR-MSG may itself contain `%'
+      (llm-provider-batch-embeddings-extract-result provider response))))
+
+(cl-defmethod llm-batch-embeddings-async ((provider llm-standard-full-provider) string-list vector-callback error-callback)
+  (llm-provider-request-prelude provider) ; one argument, as at the other call sites
+  (let ((buf (current-buffer)))           ; handed to `llm-provider-utils-callback-in-buffer'
+    (llm-request-plz-async
+     (llm-provider-embedding-url provider t) ; t = batch request
+     :headers (llm-provider-headers provider)
+     :data (llm-provider-batch-embeddings-request provider string-list)
+     :on-success (lambda (data)
+                   (if-let ((err-msg (llm-provider-embedding-extract-error provider data)))
+                       (llm-provider-utils-callback-in-buffer
+                        buf error-callback 'error
+                        err-msg)
+                     (llm-provider-utils-callback-in-buffer
+                      buf vector-callback
+                      (llm-provider-batch-embeddings-extract-result provider data)))) ; batch extractor, not the single-embedding one
+     :on-error (lambda (_ data)
+                 (llm-provider-utils-callback-in-buffer
+                  buf error-callback 'error
+                  (if (stringp data)
+                      data
+                    (or (llm-provider-embedding-extract-error
+                         provider data)
+                        "Unknown error")))))))
+
 (cl-defmethod llm-chat ((provider llm-standard-chat-provider) prompt)
   (llm-provider-request-prelude provider)
   (let ((response (llm-request-plz-sync (llm-provider-chat-url provider)
diff --git a/llm-vertex.el b/llm-vertex.el
index 185587291c..d91a8b5878 100644
--- a/llm-vertex.el
+++ b/llm-vertex.el
@@ -105,7 +105,7 @@ the key must be regenerated every hour."
       (setf (llm-vertex-key provider) (encode-coding-string result 'utf-8)))
     (setf (llm-vertex-key-gentime provider) (current-time))))
 
-(cl-defmethod llm-provider-embedding-url ((provider llm-vertex))
+(cl-defmethod llm-provider-embedding-url ((provider llm-vertex) &optional _)
   (format "https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:predict"
           llm-vertex-gcloud-region
           (llm-vertex-project provider)
diff --git a/llm.el b/llm.el
index 0de9c695d2..733c619c4b 100644
--- a/llm.el
+++ b/llm.el
@@ -468,6 +468,8 @@ won't have any partial responses, so basically just 
operates like
 
 `embeddings': the LLM can return vector embeddings of text.
 
+`embeddings-batch': the LLM can return many vector embeddings at the same time.
+
 `function-calls': the LLM can call functions."
   (ignore provider)
   nil)
@@ -517,6 +519,44 @@ be passed to `llm-cancel-request'."
   (when-let (info (llm-nonfree-message-info provider))
     (llm--warn-on-nonfree (llm-name provider) info)))
 
+(cl-defmethod llm-batch-embeddings (provider string-list)
+  "Return a list of embedding vectors of STRING-LIST.
+
+The list of vectors is in an order corresponding to the order of
+STRING-LIST.  This default method signals `not-implemented'.
+
+PROVIDER is the provider struct that will be used for an LLM call."
+  (ignore provider string-list)
+  (signal 'not-implemented nil))
+
+(cl-defmethod llm-batch-embeddings ((_ (eql nil)) _)
+  "Signal an error for a nil provider, a trivial configuration mistake."
+  (error "LLM provider was nil.  Please set the provider in the application 
you are using"))
+
+(cl-defmethod llm-batch-embeddings :before (provider _)
+  "Warn first if PROVIDER reports nonfree message info."
+  (let ((info (llm-nonfree-message-info provider)))
+    (when info (llm--warn-on-nonfree (llm-name provider) info))))
+
+(cl-defmethod llm-batch-embeddings-async (provider string-list vector-callback 
error-callback)
+  "Calculate a list of vector embeddings of STRING-LIST from PROVIDER.
+
+VECTOR-CALLBACK will be called with the list of vector embeddings.
+
+ERROR-CALLBACK will be called in the event of an error, with a signal
+and a string message.  This default method signals `not-implemented'."
+  (ignore provider string-list vector-callback error-callback)
+  (signal 'not-implemented nil))
+
+(cl-defmethod llm-batch-embeddings-async ((_ (eql nil)) _ _ _)
+  "Signal an error for a nil provider, a trivial configuration mistake."
+  (error "LLM provider was nil.  Please set the provider in the application 
you are using"))
+
+(cl-defmethod llm-batch-embeddings-async :before (provider _ _ _)
+  "Warn first if PROVIDER reports nonfree message info."
+  (let ((info (llm-nonfree-message-info provider)))
+    (when info (llm--warn-on-nonfree (llm-name provider) info))))
+
 (cl-defgeneric llm-count-tokens (provider string)
   "Return the number of tokens in STRING from PROVIDER.
 This may be an estimate if the LLM does not provide an exact



reply via email to

[Prev in Thread] Current Thread [Next in Thread]