diff options
Diffstat (limited to 'autogpts/autogpt/autogpt/app/agent_protocol_server.py')
-rw-r--r-- | autogpts/autogpt/autogpt/app/agent_protocol_server.py | 25 |
1 files changed, 11 insertions, 14 deletions
diff --git a/autogpts/autogpt/autogpt/app/agent_protocol_server.py b/autogpts/autogpt/autogpt/app/agent_protocol_server.py index cdaf1f460..2eb09706e 100644 --- a/autogpts/autogpt/autogpt/app/agent_protocol_server.py +++ b/autogpts/autogpt/autogpt/app/agent_protocol_server.py @@ -34,7 +34,6 @@ from autogpt.agent_manager import AgentManager from autogpt.app.utils import is_port_free from autogpt.config import Config from autogpt.core.resource.model_providers import ChatModelProvider, ModelProviderBudget -from autogpt.core.resource.model_providers.openai import OpenAIProvider from autogpt.file_storage import FileStorage from autogpt.models.action_history import ActionErrorResult, ActionSuccessResult from autogpt.utils.exceptions import AgentFinished @@ -464,20 +463,18 @@ class AgentProtocolServer: if task.additional_input and (user_id := task.additional_input.get("user_id")): _extra_request_headers["AutoGPT-UserID"] = user_id - task_llm_provider = None - if isinstance(self.llm_provider, OpenAIProvider): - settings = self.llm_provider._settings.copy() - settings.budget = task_llm_budget - settings.configuration = task_llm_provider_config # type: ignore - task_llm_provider = OpenAIProvider( - settings=settings, - logger=logger.getChild(f"Task-{task.task_id}_OpenAIProvider"), - ) - - if task_llm_provider and task_llm_provider._budget: - self._task_budgets[task.task_id] = task_llm_provider._budget + settings = self.llm_provider._settings.copy() + settings.budget = task_llm_budget + settings.configuration = task_llm_provider_config + task_llm_provider = self.llm_provider.__class__( + settings=settings, + logger=logger.getChild( + f"Task-{task.task_id}_{self.llm_provider.__class__.__name__}" + ), + ) + self._task_budgets[task.task_id] = task_llm_provider._budget # type: ignore - return task_llm_provider or self.llm_provider + return task_llm_provider def task_agent_id(task_id: str | int) -> str: |