aboutsummaryrefslogtreecommitdiff
path: root/autogpt/core/ability/builtins/query_language_model.py
blob: 95a5e09488f5368363b3b5325c65bbf2151ee8b3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import logging

from autogpt.core.ability.base import Ability, AbilityConfiguration
from autogpt.core.ability.schema import AbilityResult
from autogpt.core.planning.simple import LanguageModelConfiguration
from autogpt.core.plugin.simple import PluginLocation, PluginStorageFormat
from autogpt.core.resource.model_providers import (
    LanguageModelMessage,
    LanguageModelProvider,
    MessageRole,
    ModelProviderName,
    OpenAIModelName,
)


class QueryLanguageModel(Ability):
    """Ability that forwards a free-form query to a language model.

    The query is sent as a single user message (with no functions offered to
    the model) and the model's text reply is returned as the ``message`` of a
    successful ``AbilityResult``.
    """

    default_configuration = AbilityConfiguration(
        location=PluginLocation(
            storage_format=PluginStorageFormat.INSTALLED_PACKAGE,
            storage_route="autogpt.core.ability.builtins.QueryLanguageModel",
        ),
        language_model_required=LanguageModelConfiguration(
            model_name=OpenAIModelName.GPT3,
            provider_name=ModelProviderName.OPENAI,
            temperature=0.9,
        ),
    )

    def __init__(
        self,
        logger: logging.Logger,
        configuration: AbilityConfiguration,
        language_model_provider: LanguageModelProvider,
    ):
        """Store the collaborators this ability needs.

        Args:
            logger: Logger for this ability.
            configuration: Ability configuration; its
                ``language_model_required.model_name`` selects the model queried.
            language_model_provider: Provider used to issue the completion request.
        """
        self._logger = logger
        self._configuration = configuration
        self._language_model_provider = language_model_provider

    @classmethod
    def description(cls) -> str:
        """Return the human-readable description of this ability."""
        return (
            "Query a language model. "
            "A query should be a question and any relevant context."
        )

    @classmethod
    def arguments(cls) -> dict:
        """Return the argument schema: a single string ``query`` argument."""
        return {
            "query": {
                "type": "string",
                "description": (
                    "A query for a language model. "
                    "A query should contain a question and any relevant context."
                ),
            },
        }

    @classmethod
    def required_arguments(cls) -> list[str]:
        """Return the names of the arguments that must be supplied."""
        return ["query"]

    async def __call__(self, query: str) -> AbilityResult:
        """Send ``query`` to the configured language model and return its reply.

        Args:
            query: The question plus any relevant context; sent verbatim as a
                single user message.

        Returns:
            A successful ``AbilityResult`` whose ``message`` is the model's text
            reply, or an empty string if the model returned no text content.
        """
        messages = [
            LanguageModelMessage(
                content=query,
                role=MessageRole.USER,
            ),
        ]
        # NOTE(review): only ``model_name`` is forwarded from
        # ``language_model_required``; the configured ``temperature`` (0.9 in
        # the default configuration) is not passed to the provider here.
        # Confirm whether ``create_language_completion`` accepts it before
        # wiring it through.
        model_response = await self._language_model_provider.create_language_completion(
            model_prompt=messages,
            functions=[],
            model_name=self._configuration.language_model_required.model_name,
            completion_parser=self._parse_response,
        )
        return AbilityResult(
            ability_name=self.name(),
            ability_args={"query": query},
            success=True,
            message=model_response.content["content"],
        )

    @staticmethod
    def _parse_response(response_content: dict) -> dict:
        """Extract the text reply from a raw completion message.

        Args:
            response_content: The provider's response message as a mapping;
                the reply text is expected under the ``"content"`` key.

        Returns:
            ``{"content": text}``, where a ``None`` content is normalised to an
            empty string so the value can always be used as a text message.

        Raises:
            KeyError: If the message has no ``"content"`` key at all.
        """
        # Chat-completion messages may legitimately carry ``content=None``
        # (e.g. when the model replies with a function call instead of text).
        content = response_content["content"]
        return {"content": "" if content is None else content}