aboutsummaryrefslogtreecommitdiff
path: root/tests/unit/test_retry_provider_openai.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/unit/test_retry_provider_openai.py')
-rw-r--r--tests/unit/test_retry_provider_openai.py130
1 files changed, 0 insertions, 130 deletions
diff --git a/tests/unit/test_retry_provider_openai.py b/tests/unit/test_retry_provider_openai.py
deleted file mode 100644
index 1b23f5d26..000000000
--- a/tests/unit/test_retry_provider_openai.py
+++ /dev/null
@@ -1,130 +0,0 @@
-import pytest
-from openai.error import APIError, RateLimitError, ServiceUnavailableError
-
-from autogpt.llm.providers import openai
-
-
-@pytest.fixture(params=[RateLimitError, ServiceUnavailableError, APIError])
-def error(request):
- if request.param == APIError:
- return request.param("Error", http_status=502)
- else:
- return request.param("Error")
-
-
-def error_factory(error_instance, error_count, retry_count, warn_user=True):
- """Creates errors"""
-
- class RaisesError:
- def __init__(self):
- self.count = 0
-
- @openai.retry_api(
- max_retries=retry_count, backoff_base=0.001, warn_user=warn_user
- )
- def __call__(self):
- self.count += 1
- if self.count <= error_count:
- raise error_instance
- return self.count
-
- return RaisesError()
-
-
-def test_retry_open_api_no_error(capsys):
- """Tests the retry functionality with no errors expected"""
-
- @openai.retry_api()
- def f():
- return 1
-
- result = f()
- assert result == 1
-
- output = capsys.readouterr()
- assert output.out == ""
- assert output.err == ""
-
-
-@pytest.mark.parametrize(
- "error_count, retry_count, failure",
- [(2, 10, False), (2, 2, False), (10, 2, True), (3, 2, True), (1, 0, True)],
- ids=["passing", "passing_edge", "failing", "failing_edge", "failing_no_retries"],
-)
-def test_retry_open_api_passing(capsys, error, error_count, retry_count, failure):
- """Tests the retry with simulated errors [RateLimitError, ServiceUnavailableError, APIError], but should ulimately pass"""
- call_count = min(error_count, retry_count) + 1
-
- raises = error_factory(error, error_count, retry_count)
- if failure:
- with pytest.raises(type(error)):
- raises()
- else:
- result = raises()
- assert result == call_count
-
- assert raises.count == call_count
-
- output = capsys.readouterr()
-
- if error_count and retry_count:
- if type(error) == RateLimitError:
- assert "Reached rate limit" in output.out
- assert "Please double check" in output.out
- if type(error) == ServiceUnavailableError:
- assert "The OpenAI API engine is currently overloaded" in output.out
- assert "Please double check" in output.out
- else:
- assert output.out == ""
-
-
-def test_retry_open_api_rate_limit_no_warn(capsys):
- """Tests the retry logic with a rate limit error"""
- error_count = 2
- retry_count = 10
-
- raises = error_factory(RateLimitError, error_count, retry_count, warn_user=False)
- result = raises()
- call_count = min(error_count, retry_count) + 1
- assert result == call_count
- assert raises.count == call_count
-
- output = capsys.readouterr()
-
- assert "Reached rate limit" in output.out
- assert "Please double check" not in output.out
-
-
-def test_retry_open_api_service_unavairable_no_warn(capsys):
- """Tests the retry logic with a service unavairable error"""
- error_count = 2
- retry_count = 10
-
- raises = error_factory(
- ServiceUnavailableError, error_count, retry_count, warn_user=False
- )
- result = raises()
- call_count = min(error_count, retry_count) + 1
- assert result == call_count
- assert raises.count == call_count
-
- output = capsys.readouterr()
-
- assert "The OpenAI API engine is currently overloaded" in output.out
- assert "Please double check" not in output.out
-
-
-def test_retry_openapi_other_api_error(capsys):
- """Tests the Retry logic with a non rate limit error such as HTTP500"""
- error_count = 2
- retry_count = 10
-
- raises = error_factory(APIError("Error", http_status=500), error_count, retry_count)
-
- with pytest.raises(APIError):
- raises()
- call_count = 1
- assert raises.count == call_count
-
- output = capsys.readouterr()
- assert output.out == ""