aboutsummaryrefslogtreecommitdiff
path: root/autogpt/memory/vector/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'autogpt/memory/vector/utils.py')
-rw-r--r--autogpt/memory/vector/utils.py25
1 files changed, 10 insertions, 15 deletions
diff --git a/autogpt/memory/vector/utils.py b/autogpt/memory/vector/utils.py
index 75d1f69d4..beb2fcf93 100644
--- a/autogpt/memory/vector/utils.py
+++ b/autogpt/memory/vector/utils.py
@@ -1,16 +1,14 @@
from typing import Any, overload
import numpy as np
-import openai
from autogpt.config import Config
-from autogpt.llm.utils import metered, retry_openai_api
+from autogpt.llm.base import TText
+from autogpt.llm.providers import openai as iopenai
from autogpt.logs import logger
Embedding = list[np.float32] | np.ndarray[Any, np.dtype[np.float32]]
"""Embedding vector"""
-TText = list[int]
-"""Token array representing text"""
@overload
@@ -23,10 +21,8 @@ def get_embedding(input: list[str] | list[TText]) -> list[Embedding]:
...
-@metered
-@retry_openai_api()
def get_embedding(
- input: str | TText | list[str] | list[TText],
+ input: str | TText | list[str] | list[TText], config: Config
) -> Embedding | list[Embedding]:
"""Get an embedding from the ada model.
@@ -37,7 +33,6 @@ def get_embedding(
Returns:
List[float]: The embedding.
"""
- cfg = Config()
multiple = isinstance(input, list) and all(not isinstance(i, int) for i in input)
if isinstance(input, str):
@@ -45,22 +40,22 @@ def get_embedding(
elif multiple and isinstance(input[0], str):
input = [text.replace("\n", " ") for text in input]
- model = cfg.embedding_model
- if cfg.use_azure:
- kwargs = {"engine": cfg.get_azure_deployment_id_for_model(model)}
+ model = config.embedding_model
+ if config.use_azure:
+ kwargs = {"engine": config.get_azure_deployment_id_for_model(model)}
else:
kwargs = {"model": model}
logger.debug(
f"Getting embedding{f's for {len(input)} inputs' if multiple else ''}"
f" with model '{model}'"
- + (f" via Azure deployment '{kwargs['engine']}'" if cfg.use_azure else "")
+ + (f" via Azure deployment '{kwargs['engine']}'" if config.use_azure else "")
)
- embeddings = openai.Embedding.create(
- input=input,
- api_key=cfg.openai_api_key,
+ embeddings = iopenai.create_embedding(
+ input,
**kwargs,
+ api_key=config.openai_api_key,
).data
if not multiple: