diff options
Diffstat (limited to 'autogpt/memory/vector/utils.py')
-rw-r--r-- | autogpt/memory/vector/utils.py | 25 |
1 files changed, 10 insertions, 15 deletions
diff --git a/autogpt/memory/vector/utils.py b/autogpt/memory/vector/utils.py index 75d1f69d4..beb2fcf93 100644 --- a/autogpt/memory/vector/utils.py +++ b/autogpt/memory/vector/utils.py @@ -1,16 +1,14 @@ from typing import Any, overload import numpy as np -import openai from autogpt.config import Config -from autogpt.llm.utils import metered, retry_openai_api +from autogpt.llm.base import TText +from autogpt.llm.providers import openai as iopenai from autogpt.logs import logger Embedding = list[np.float32] | np.ndarray[Any, np.dtype[np.float32]] """Embedding vector""" -TText = list[int] -"""Token array representing text""" @overload @@ -23,10 +21,8 @@ def get_embedding(input: list[str] | list[TText]) -> list[Embedding]: ... -@metered -@retry_openai_api() def get_embedding( - input: str | TText | list[str] | list[TText], + input: str | TText | list[str] | list[TText], config: Config ) -> Embedding | list[Embedding]: """Get an embedding from the ada model. @@ -37,7 +33,6 @@ def get_embedding( Returns: List[float]: The embedding. """ - cfg = Config() multiple = isinstance(input, list) and all(not isinstance(i, int) for i in input) if isinstance(input, str): @@ -45,22 +40,22 @@ def get_embedding( elif multiple and isinstance(input[0], str): input = [text.replace("\n", " ") for text in input] - model = cfg.embedding_model - if cfg.use_azure: - kwargs = {"engine": cfg.get_azure_deployment_id_for_model(model)} + model = config.embedding_model + if config.use_azure: + kwargs = {"engine": config.get_azure_deployment_id_for_model(model)} else: kwargs = {"model": model} logger.debug( f"Getting embedding{f's for {len(input)} inputs' if multiple else ''}" f" with model '{model}'" - + (f" via Azure deployment '{kwargs['engine']}'" if cfg.use_azure else "") + + (f" via Azure deployment '{kwargs['engine']}'" if config.use_azure else "") ) - embeddings = openai.Embedding.create( - input=input, - api_key=cfg.openai_api_key, + embeddings = iopenai.create_embedding( + input, **kwargs, + api_key=config.openai_api_key, ).data if not multiple: |