aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGravatar Krzysztof Czerwinski <34861343+kcze@users.noreply.github.com> 2024-03-21 18:11:36 +0100
committerGravatar GitHub <noreply@github.com> 2024-03-21 18:11:36 +0100
commit76d6e61941e128317e3171de820bc8dc719bd082 (patch)
tree71282e8cc46ab4e1caa6c89bfa0e55473fb9e564
parentfix(agent): Add check for Linux container support to `is_docker_available` (diff)
downloadAuto-GPT-76d6e61941e128317e3171de820bc8dc719bd082.tar.gz
Auto-GPT-76d6e61941e128317e3171de820bc8dc719bd082.tar.bz2
Auto-GPT-76d6e61941e128317e3171de820bc8dc719bd082.zip
feat(agent): Implement more fault tolerant `json_loads` function (#7016)
* Implement syntax fault tolerant `json_loads` function using `dem3json` - Add `dem3json` dependency * Replace `json.loads` by `json_loads` in places where malformed JSON may occur * Move `json_utils.py` to `autogpt/core/utils` * Add tests for `json_utils` --------- Co-authored-by: Reinier van der Leer <pwuts@agpt.co>
-rw-r--r--autogpts/autogpt/autogpt/agent_factory/profile_generator.py2
-rw-r--r--autogpts/autogpt/autogpt/agents/prompt_strategies/one_shot.py6
-rw-r--r--autogpts/autogpt/autogpt/core/planning/prompt_strategies/initial_plan.py3
-rw-r--r--autogpts/autogpt/autogpt/core/planning/prompt_strategies/name_and_goals.py2
-rw-r--r--autogpts/autogpt/autogpt/core/planning/prompt_strategies/next_ability.py3
-rw-r--r--autogpts/autogpt/autogpt/core/prompting/utils.py20
-rw-r--r--autogpts/autogpt/autogpt/core/resource/model_providers/openai.py6
-rw-r--r--autogpts/autogpt/autogpt/core/utils/json_utils.py92
-rw-r--r--autogpts/autogpt/autogpt/json_utils/__init__.py0
-rw-r--r--autogpts/autogpt/autogpt/json_utils/utilities.py55
-rw-r--r--autogpts/autogpt/autogpt/processing/text.py7
-rw-r--r--autogpts/autogpt/poetry.lock12
-rw-r--r--autogpts/autogpt/pyproject.toml1
-rw-r--r--autogpts/autogpt/tests/unit/test_json_utils.py93
-rw-r--r--autogpts/autogpt/tests/unit/test_utils.py18
15 files changed, 217 insertions, 103 deletions
diff --git a/autogpts/autogpt/autogpt/agent_factory/profile_generator.py b/autogpts/autogpt/autogpt/agent_factory/profile_generator.py
index ea4a602e4..889b7f2d4 100644
--- a/autogpts/autogpt/autogpt/agent_factory/profile_generator.py
+++ b/autogpts/autogpt/autogpt/agent_factory/profile_generator.py
@@ -8,7 +8,6 @@ from autogpt.core.prompting import (
LanguageModelClassification,
PromptStrategy,
)
-from autogpt.core.prompting.utils import json_loads
from autogpt.core.resource.model_providers.schema import (
AssistantChatMessage,
ChatMessage,
@@ -16,6 +15,7 @@ from autogpt.core.resource.model_providers.schema import (
CompletionModelFunction,
)
from autogpt.core.utils.json_schema import JSONSchema
+from autogpt.core.utils.json_utils import json_loads
logger = logging.getLogger(__name__)
diff --git a/autogpts/autogpt/autogpt/agents/prompt_strategies/one_shot.py b/autogpts/autogpt/autogpt/agents/prompt_strategies/one_shot.py
index 243d07f4e..72916e1fa 100644
--- a/autogpts/autogpt/autogpt/agents/prompt_strategies/one_shot.py
+++ b/autogpts/autogpt/autogpt/agents/prompt_strategies/one_shot.py
@@ -26,7 +26,7 @@ from autogpt.core.resource.model_providers.schema import (
CompletionModelFunction,
)
from autogpt.core.utils.json_schema import JSONSchema
-from autogpt.json_utils.utilities import extract_dict_from_response
+from autogpt.core.utils.json_utils import extract_dict_from_json, json_loads
from autogpt.prompts.utils import format_numbered_list, indent
@@ -386,7 +386,7 @@ class OneShotAgentPromptStrategy(PromptStrategy):
else f" '{response.content}'"
)
)
- assistant_reply_dict = extract_dict_from_response(response.content)
+ assistant_reply_dict = extract_dict_from_json(response.content)
self.logger.debug(
"Validating object extracted from LLM response:\n"
f"{json.dumps(assistant_reply_dict, indent=4)}"
@@ -439,7 +439,7 @@ def extract_command(
raise InvalidAgentResponseError("No 'tool_calls' in assistant reply")
assistant_reply_json["command"] = {
"name": assistant_reply.tool_calls[0].function.name,
- "args": json.loads(assistant_reply.tool_calls[0].function.arguments),
+ "args": json_loads(assistant_reply.tool_calls[0].function.arguments),
}
try:
if not isinstance(assistant_reply_json, dict):
diff --git a/autogpts/autogpt/autogpt/core/planning/prompt_strategies/initial_plan.py b/autogpts/autogpt/autogpt/core/planning/prompt_strategies/initial_plan.py
index 9912fb6e2..d26d86fd6 100644
--- a/autogpts/autogpt/autogpt/core/planning/prompt_strategies/initial_plan.py
+++ b/autogpts/autogpt/autogpt/core/planning/prompt_strategies/initial_plan.py
@@ -4,13 +4,14 @@ from autogpt.core.configuration import SystemConfiguration, UserConfigurable
from autogpt.core.planning.schema import Task, TaskType
from autogpt.core.prompting import PromptStrategy
from autogpt.core.prompting.schema import ChatPrompt, LanguageModelClassification
-from autogpt.core.prompting.utils import json_loads, to_numbered_list
+from autogpt.core.prompting.utils import to_numbered_list
from autogpt.core.resource.model_providers import (
AssistantChatMessage,
ChatMessage,
CompletionModelFunction,
)
from autogpt.core.utils.json_schema import JSONSchema
+from autogpt.core.utils.json_utils import json_loads
logger = logging.getLogger(__name__)
diff --git a/autogpts/autogpt/autogpt/core/planning/prompt_strategies/name_and_goals.py b/autogpts/autogpt/autogpt/core/planning/prompt_strategies/name_and_goals.py
index 907a9717b..d030c05e1 100644
--- a/autogpts/autogpt/autogpt/core/planning/prompt_strategies/name_and_goals.py
+++ b/autogpts/autogpt/autogpt/core/planning/prompt_strategies/name_and_goals.py
@@ -3,13 +3,13 @@ import logging
from autogpt.core.configuration import SystemConfiguration, UserConfigurable
from autogpt.core.prompting import PromptStrategy
from autogpt.core.prompting.schema import ChatPrompt, LanguageModelClassification
-from autogpt.core.prompting.utils import json_loads
from autogpt.core.resource.model_providers import (
AssistantChatMessage,
ChatMessage,
CompletionModelFunction,
)
from autogpt.core.utils.json_schema import JSONSchema
+from autogpt.core.utils.json_utils import json_loads
logger = logging.getLogger(__name__)
diff --git a/autogpts/autogpt/autogpt/core/planning/prompt_strategies/next_ability.py b/autogpts/autogpt/autogpt/core/planning/prompt_strategies/next_ability.py
index 9672f8e5b..dec67c295 100644
--- a/autogpts/autogpt/autogpt/core/planning/prompt_strategies/next_ability.py
+++ b/autogpts/autogpt/autogpt/core/planning/prompt_strategies/next_ability.py
@@ -4,13 +4,14 @@ from autogpt.core.configuration import SystemConfiguration, UserConfigurable
from autogpt.core.planning.schema import Task
from autogpt.core.prompting import PromptStrategy
from autogpt.core.prompting.schema import ChatPrompt, LanguageModelClassification
-from autogpt.core.prompting.utils import json_loads, to_numbered_list
+from autogpt.core.prompting.utils import to_numbered_list
from autogpt.core.resource.model_providers import (
AssistantChatMessage,
ChatMessage,
CompletionModelFunction,
)
from autogpt.core.utils.json_schema import JSONSchema
+from autogpt.core.utils.json_utils import json_loads
logger = logging.getLogger(__name__)
diff --git a/autogpts/autogpt/autogpt/core/prompting/utils.py b/autogpts/autogpt/autogpt/core/prompting/utils.py
index 4b1be47f4..865b3fc08 100644
--- a/autogpts/autogpt/autogpt/core/prompting/utils.py
+++ b/autogpts/autogpt/autogpt/core/prompting/utils.py
@@ -1,7 +1,3 @@
-import ast
-import json
-
-
def to_numbered_list(
items: list[str], no_items_response: str = "", **template_args
) -> str:
@@ -11,19 +7,3 @@ def to_numbered_list(
)
else:
return no_items_response
-
-
-def json_loads(json_str: str):
- # TODO: this is a hack function for now. We'll see what errors show up in testing.
- # Can hopefully just replace with a call to ast.literal_eval.
- # Can't use json.loads because the function API still sometimes returns json strings
- # with minor issues like trailing commas.
- try:
- json_str = json_str[json_str.index("{") : json_str.rindex("}") + 1]
- return ast.literal_eval(json_str)
- except json.decoder.JSONDecodeError as e:
- try:
- print(f"json decode error {e}. trying literal eval")
- return ast.literal_eval(json_str)
- except Exception:
- breakpoint()
diff --git a/autogpts/autogpt/autogpt/core/resource/model_providers/openai.py b/autogpts/autogpt/autogpt/core/resource/model_providers/openai.py
index 1b564b6df..dfaa4ff03 100644
--- a/autogpts/autogpt/autogpt/core/resource/model_providers/openai.py
+++ b/autogpts/autogpt/autogpt/core/resource/model_providers/openai.py
@@ -38,6 +38,7 @@ from autogpt.core.resource.model_providers.schema import (
ModelTokenizer,
)
from autogpt.core.utils.json_schema import JSONSchema
+from autogpt.core.utils.json_utils import json_loads
_T = TypeVar("_T")
_P = ParamSpec("_P")
@@ -758,19 +759,18 @@ def _functions_compat_fix_kwargs(
def _tool_calls_compat_extract_calls(response: str) -> Iterator[AssistantToolCall]:
- import json
import re
import uuid
logging.debug(f"Trying to extract tool calls from response:\n{response}")
if response[0] == "[":
- tool_calls: list[AssistantToolCallDict] = json.loads(response)
+ tool_calls: list[AssistantToolCallDict] = json_loads(response)
else:
block = re.search(r"```(?:tool_calls)?\n(.*)\n```\s*$", response, re.DOTALL)
if not block:
raise ValueError("Could not find tool_calls block in response")
- tool_calls: list[AssistantToolCallDict] = json.loads(block.group(1))
+ tool_calls: list[AssistantToolCallDict] = json_loads(block.group(1))
for t in tool_calls:
t["id"] = str(uuid.uuid4())
diff --git a/autogpts/autogpt/autogpt/core/utils/json_utils.py b/autogpts/autogpt/autogpt/core/utils/json_utils.py
new file mode 100644
index 000000000..664cb87c1
--- /dev/null
+++ b/autogpts/autogpt/autogpt/core/utils/json_utils.py
@@ -0,0 +1,92 @@
+import io
+import logging
+import re
+from typing import Any
+
+import demjson3
+
+logger = logging.getLogger(__name__)
+
+
+def json_loads(json_str: str) -> Any:
+ """Parse a JSON string, tolerating minor syntax issues:
+ - Missing, extra and trailing commas
+ - Extraneous newlines and whitespace outside of string literals
+ - Inconsistent spacing after colons and commas
+ - Missing closing brackets or braces
+ - Numbers: binary, hex, octal, trailing and prefixed decimal points
+ - Different encodings
+ - Surrounding markdown code block
+ - Comments
+
+ Args:
+ json_str: The JSON string to parse.
+
+ Returns:
+ The parsed JSON object, same as built-in json.loads.
+ """
+ # Remove possible code block
+ pattern = r"```(?:json|JSON)*([\s\S]*?)```"
+ match = re.search(pattern, json_str)
+
+ if match:
+ json_str = match.group(1).strip()
+
+ error_buffer = io.StringIO()
+ json_result = demjson3.decode(
+ json_str, return_errors=True, write_errors=error_buffer
+ )
+
+ if error_buffer.getvalue():
+ logger.debug(f"JSON parse errors:\n{error_buffer.getvalue()}")
+
+ if json_result is None:
+ raise ValueError(f"Failed to parse JSON string: {json_str}")
+
+ return json_result.object
+
+
+def extract_dict_from_json(json_str: str) -> dict[str, Any]:
+ # Sometimes the response includes the JSON in a code block with ```
+ pattern = r"```(?:json|JSON)*([\s\S]*?)```"
+ match = re.search(pattern, json_str)
+
+ if match:
+ json_str = match.group(1).strip()
+ else:
+ # The string may contain JSON.
+ json_pattern = r"{[\s\S]*}"
+ match = re.search(json_pattern, json_str)
+
+ if match:
+ json_str = match.group()
+
+ result = json_loads(json_str)
+ if not isinstance(result, dict):
+ raise ValueError(
+ f"Response '''{json_str}''' evaluated to non-dict value {repr(result)}"
+ )
+ return result
+
+
+def extract_list_from_json(json_str: str) -> list[Any]:
+ # Sometimes the response includes the JSON in a code block with ```
+ pattern = r"```(?:json|JSON)*([\s\S]*?)```"
+ match = re.search(pattern, json_str)
+
+ if match:
+ json_str = match.group(1).strip()
+ else:
+ # The string may contain JSON.
+ json_pattern = r"\[[\s\S]*\]"
+ match = re.search(json_pattern, json_str)
+
+ if match:
+ json_str = match.group()
+
+ result = json_loads(json_str)
+ if not isinstance(result, list):
+ raise ValueError(
+ f"Response '''{json_str}''' evaluated to non-list value {repr(result)}"
+ )
+ return result
diff --git a/autogpts/autogpt/autogpt/json_utils/__init__.py b/autogpts/autogpt/autogpt/json_utils/__init__.py
deleted file mode 100644
index e69de29bb..000000000
--- a/autogpts/autogpt/autogpt/json_utils/__init__.py
+++ /dev/null
diff --git a/autogpts/autogpt/autogpt/json_utils/utilities.py b/autogpts/autogpt/autogpt/json_utils/utilities.py
deleted file mode 100644
index c81ff2271..000000000
--- a/autogpts/autogpt/autogpt/json_utils/utilities.py
+++ /dev/null
@@ -1,55 +0,0 @@
-"""Utilities for the json_fixes package."""
-import json
-import logging
-import re
-from typing import Any
-
-logger = logging.getLogger(__name__)
-
-
-def extract_dict_from_response(response_content: str) -> dict[str, Any]:
- # Sometimes the response includes the JSON in a code block with ```
- pattern = r"```(?:json|JSON)*([\s\S]*?)```"
- match = re.search(pattern, response_content)
-
- if match:
- response_content = match.group(1).strip()
- else:
- # The string may contain JSON.
- json_pattern = r"{[\s\S]*}"
- match = re.search(json_pattern, response_content)
-
- if match:
- response_content = match.group()
-
- result = json.loads(response_content)
- if not isinstance(result, dict):
- raise ValueError(
- f"Response '''{response_content}''' evaluated to "
- f"non-dict value {repr(result)}"
- )
- return result
-
-
-def extract_list_from_response(response_content: str) -> list[Any]:
- # Sometimes the response includes the JSON in a code block with ```
- pattern = r"```(?:json|JSON)*([\s\S]*?)```"
- match = re.search(pattern, response_content)
-
- if match:
- response_content = match.group(1).strip()
- else:
- # The string may contain JSON.
- json_pattern = r"\[[\s\S]*\]"
- match = re.search(json_pattern, response_content)
-
- if match:
- response_content = match.group()
-
- result = json.loads(response_content)
- if not isinstance(result, list):
- raise ValueError(
- f"Response '''{response_content}''' evaluated to "
- f"non-list value {repr(result)}"
- )
- return result
diff --git a/autogpts/autogpt/autogpt/processing/text.py b/autogpts/autogpt/autogpt/processing/text.py
index de0e1bf3f..8e5c0794b 100644
--- a/autogpts/autogpt/autogpt/processing/text.py
+++ b/autogpts/autogpt/autogpt/processing/text.py
@@ -1,4 +1,5 @@
"""Text processing functions"""
+
import logging
import math
from typing import Iterator, Optional, TypeVar
@@ -12,7 +13,7 @@ from autogpt.core.resource.model_providers import (
ChatModelProvider,
ModelTokenizer,
)
-from autogpt.json_utils.utilities import extract_list_from_response
+from autogpt.core.utils.json_utils import extract_list_from_json
logger = logging.getLogger(__name__)
@@ -161,9 +162,7 @@ async def _process_text(
temperature=0.5,
max_tokens=max_result_tokens,
completion_parser=lambda s: (
- extract_list_from_response(s.content)
- if output_type is not str
- else None
+ extract_list_from_json(s.content) if output_type is not str else None
),
)
diff --git a/autogpts/autogpt/poetry.lock b/autogpts/autogpt/poetry.lock
index 6ee147a45..1c6ab48e1 100644
--- a/autogpts/autogpt/poetry.lock
+++ b/autogpts/autogpt/poetry.lock
@@ -1617,6 +1617,16 @@ files = [
]
[[package]]
+name = "demjson3"
+version = "3.0.6"
+description = "encoder, decoder, and lint/validator for JSON (JavaScript Object Notation) compliant with RFC 7159"
+optional = false
+python-versions = "*"
+files = [
+ {file = "demjson3-3.0.6.tar.gz", hash = "sha256:37c83b0c6eb08d25defc88df0a2a4875d58a7809a9650bd6eee7afd8053cdbac"},
+]
+
+[[package]]
name = "deprecated"
version = "1.2.14"
description = "Python @deprecated decorator to deprecate old python classes, functions or methods."
@@ -7248,4 +7258,4 @@ benchmark = ["agbenchmark"]
[metadata]
lock-version = "2.0"
python-versions = "^3.10"
-content-hash = "a09e20daaf94e05457ded6a9989585cd37edf96a036c5f9e505f0b2456403a25"
+content-hash = "9e28a0449253ec931297aa655fcd09da5e9f5d57bd73863419ce4f477018ef8a"
diff --git a/autogpts/autogpt/pyproject.toml b/autogpts/autogpt/pyproject.toml
index ad138ee24..73706967a 100644
--- a/autogpts/autogpt/pyproject.toml
+++ b/autogpts/autogpt/pyproject.toml
@@ -30,6 +30,7 @@ boto3 = "^1.33.6"
charset-normalizer = "^3.1.0"
click = "*"
colorama = "^0.4.6"
+demjson3 = "^3.0.0"
distro = "^1.8.0"
docker = "*"
duckduckgo-search = "^4.0.0"
diff --git a/autogpts/autogpt/tests/unit/test_json_utils.py b/autogpts/autogpt/tests/unit/test_json_utils.py
new file mode 100644
index 000000000..fdd1b0f08
--- /dev/null
+++ b/autogpts/autogpt/tests/unit/test_json_utils.py
@@ -0,0 +1,93 @@
+import json
+
+import pytest
+
+from autogpt.core.utils.json_utils import json_loads
+
+_JSON_FIXABLE: list[tuple[str, str]] = [
+ # Missing comma
+ ('{"name": "John Doe" "age": 30,}', '{"name": "John Doe", "age": 30}'),
+ ("[1, 2 3]", "[1, 2, 3]"),
+ # Trailing comma
+ ('{"name": "John Doe", "age": 30,}', '{"name": "John Doe", "age": 30}'),
+ ("[1, 2, 3,]", "[1, 2, 3]"),
+ # Extra comma in object
+ ('{"name": "John Doe",, "age": 30}', '{"name": "John Doe", "age": 30}'),
+ # Extra newlines
+ ('{"name": "John Doe",\n"age": 30}', '{"name": "John Doe", "age": 30}'),
+ ("[1, 2,\n3]", "[1, 2, 3]"),
+ # Missing closing brace or bracket
+ ('{"name": "John Doe", "age": 30', '{"name": "John Doe", "age": 30}'),
+ ("[1, 2, 3", "[1, 2, 3]"),
+ # Different numerals
+ ("[+1, ---2, .5, +-4.5, 123.]", "[1, -2, 0.5, -4.5, 123]"),
+ ('{"bin": 0b1001, "hex": 0x1A, "oct": 0o17}', '{"bin": 9, "hex": 26, "oct": 15}'),
+ # Broken array
+ (
+ '[1, 2 3, "yes" true, false null, 25, {"obj": "var"}',
+ '[1, 2, 3, "yes", true, false, null, 25, {"obj": "var"}]',
+ ),
+ # Codeblock
+ (
+ '```json\n{"name": "John Doe", "age": 30}\n```',
+ '{"name": "John Doe", "age": 30}',
+ ),
+ # Mutliple problems
+ (
+ '{"name":"John Doe" "age": 30\n "empty": "","address": '
+ "// random comment\n"
+ '{"city": "New York", "state": "NY"},'
+ '"skills": ["Python" "C++", "Java",""],',
+ '{"name": "John Doe", "age": 30, "empty": "", "address": '
+ '{"city": "New York", "state": "NY"}, '
+ '"skills": ["Python", "C++", "Java", ""]}',
+ ),
+ # All good
+ (
+ '{"name": "John Doe", "age": 30, "address": '
+ '{"city": "New York", "state": "NY"}, '
+ '"skills": ["Python", "C++", "Java"]}',
+ '{"name": "John Doe", "age": 30, "address": '
+ '{"city": "New York", "state": "NY"}, '
+ '"skills": ["Python", "C++", "Java"]}',
+ ),
+ ("true", "true"),
+ ("false", "false"),
+ ("null", "null"),
+ ("123.5", "123.5"),
+ ('"Hello, World!"', '"Hello, World!"'),
+ ("{}", "{}"),
+ ("[]", "[]"),
+]
+
+_JSON_UNFIXABLE: list[tuple[str, str]] = [
+ # Broken booleans and null
+ ("[TRUE, False, NULL]", "[true, false, null]"),
+ # Missing values in array
+ ("[1, , 3]", "[1, 3]"),
+ # Leading zeros (are treated as octal)
+ ("[0023, 015]", "[23, 15]"),
+ # Missing quotes
+ ('{"name": John Doe}', '{"name": "John Doe"}'),
+ # Missing opening braces or bracket
+ ('"name": "John Doe"}', '{"name": "John Doe"}'),
+ ("1, 2, 3]", "[1, 2, 3]"),
+]
+
+
+@pytest.fixture(params=_JSON_FIXABLE)
+def fixable_json(request: pytest.FixtureRequest) -> tuple[str, str]:
+ return request.param
+
+
+@pytest.fixture(params=_JSON_UNFIXABLE)
+def unfixable_json(request: pytest.FixtureRequest) -> tuple[str, str]:
+ return request.param
+
+
+def test_json_loads_fixable(fixable_json: tuple[str, str]):
+ assert json_loads(fixable_json[0]) == json.loads(fixable_json[1])
+
+
+def test_json_loads_unfixable(unfixable_json: tuple[str, str]):
+ assert json_loads(unfixable_json[0]) != json.loads(unfixable_json[1])
diff --git a/autogpts/autogpt/tests/unit/test_utils.py b/autogpts/autogpt/tests/unit/test_utils.py
index 39f7586b8..9224f7212 100644
--- a/autogpts/autogpt/tests/unit/test_utils.py
+++ b/autogpts/autogpt/tests/unit/test_utils.py
@@ -14,7 +14,7 @@ from autogpt.app.utils import (
get_latest_bulletin,
set_env_config_value,
)
-from autogpt.json_utils.utilities import extract_dict_from_response
+from autogpt.core.utils.json_utils import extract_dict_from_json
from autogpt.utils import validate_yaml_file
from tests.utils import skip_in_ci
@@ -199,34 +199,26 @@ def test_get_current_git_branch_failure(mock_repo):
def test_extract_json_from_response(valid_json_response: dict):
emulated_response_from_openai = json.dumps(valid_json_response)
- assert (
- extract_dict_from_response(emulated_response_from_openai) == valid_json_response
- )
+ assert extract_dict_from_json(emulated_response_from_openai) == valid_json_response
def test_extract_json_from_response_wrapped_in_code_block(valid_json_response: dict):
emulated_response_from_openai = "```" + json.dumps(valid_json_response) + "```"
- assert (
- extract_dict_from_response(emulated_response_from_openai) == valid_json_response
- )
+ assert extract_dict_from_json(emulated_response_from_openai) == valid_json_response
def test_extract_json_from_response_wrapped_in_code_block_with_language(
valid_json_response: dict,
):
emulated_response_from_openai = "```json" + json.dumps(valid_json_response) + "```"
- assert (
- extract_dict_from_response(emulated_response_from_openai) == valid_json_response
- )
+ assert extract_dict_from_json(emulated_response_from_openai) == valid_json_response
def test_extract_json_from_response_json_contained_in_string(valid_json_response: dict):
emulated_response_from_openai = (
"sentence1" + json.dumps(valid_json_response) + "sentence2"
)
- assert (
- extract_dict_from_response(emulated_response_from_openai) == valid_json_response
- )
+ assert extract_dict_from_json(emulated_response_from_openai) == valid_json_response
@pytest.fixture