diff --git a/src/openai/__init__.py b/src/openai/__init__.py index b944fbed5e..3391e8fa8c 100644 --- a/src/openai/__init__.py +++ b/src/openai/__init__.py @@ -117,7 +117,7 @@ from ._base_client import DEFAULT_TIMEOUT, DEFAULT_MAX_RETRIES -api_key: str | None = None +api_key: str | _t.Callable[[], str] | None = None organization: str | None = None @@ -156,16 +156,27 @@ class _ModuleClient(OpenAI): @property # type: ignore @override - def api_key(self) -> str | None: - return api_key + def api_key(self) -> str | _t.Callable[[], str] | None: + return api_key() if callable(api_key) else api_key @api_key.setter # type: ignore - def api_key(self, value: str | None) -> None: # type: ignore + def api_key(self, value: str | _t.Callable[[], str] | None) -> None: # type: ignore global api_key + api_key = value + + @property + def _api_key_provider(self) -> _t.Callable[[], str] | None: # type: ignore + return None + @_api_key_provider.setter + def _api_key_provider(self, value: _t.Callable[[], str] | None) -> None: # type: ignore + global api_key + # Yes, setting the api_key is intentional. The module level client accepts callables + # for the module level api_key and will call it to retrieve the value + # if it is a callable. api_key = value - @property # type: ignore + @property @override def organization(self) -> str | None: return organization diff --git a/src/openai/_client.py b/src/openai/_client.py index b99db786a7..50b0c99181 100644 --- a/src/openai/_client.py +++ b/src/openai/_client.py @@ -3,11 +3,13 @@ from __future__ import annotations import os -from typing import TYPE_CHECKING, Any, Union, Mapping +from typing import TYPE_CHECKING, Any, Union, Mapping, Callable, Awaitable from typing_extensions import Self, override import httpx +from openai._models import FinalRequestOptions + from . import _exceptions from ._qs import Querystring from ._types import ( @@ -94,7 +96,7 @@ class OpenAI(SyncAPIClient): def __init__( self, *, - api_key: str | None = None, + api_key: str | None | Callable[[], str] = None, organization: str | None = None, project: str | None = None, webhook_secret: str | None = None, @@ -132,7 +134,12 @@ def __init__( raise OpenAIError( "The api_key client option must be set either by passing api_key to the client or by setting the OPENAI_API_KEY environment variable" ) - self.api_key = api_key + if callable(api_key): + self.api_key = "" + self._api_key_provider: Callable[[], str] | None = api_key + else: + self.api_key = api_key or "" + self._api_key_provider = None if organization is None: organization = os.environ.get("OPENAI_ORG_ID") @@ -287,6 +294,15 @@ def with_streaming_response(self) -> OpenAIWithStreamedResponse: def qs(self) -> Querystring: return Querystring(array_format="brackets") + def _refresh_api_key(self) -> None: + if self._api_key_provider: + self.api_key = self._api_key_provider() + + @override + def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions: + self._refresh_api_key() + return super()._prepare_options(options) + @property @override def auth_headers(self) -> dict[str, str]: @@ -310,7 +326,7 @@ def default_headers(self) -> dict[str, str | Omit]: def copy( self, *, - api_key: str | None = None, + api_key: str | Callable[[], str] | None = None, organization: str | None = None, project: str | None = None, webhook_secret: str | None = None, @@ -348,7 +364,7 @@ def copy( http_client = http_client or self._client return self.__class__( - api_key=api_key or self.api_key, + api_key=api_key or self._api_key_provider or self.api_key, organization=organization or self.organization, project=project or self.project, webhook_secret=webhook_secret or self.webhook_secret, @@ -419,7 +435,7 @@ class AsyncOpenAI(AsyncAPIClient): def __init__( self, *, - api_key: str | None = None, + api_key: str | Callable[[], Awaitable[str]] | None = None, organization: str | None = None, project: str | None = None, webhook_secret: str | None = None, @@ -457,7 +473,12 @@ def __init__( raise OpenAIError( "The api_key client option must be set either by passing api_key to the client or by setting the OPENAI_API_KEY environment variable" ) - self.api_key = api_key + if callable(api_key): + self.api_key = "" + self._api_key_provider: Callable[[], Awaitable[str]] | None = api_key + else: + self.api_key = api_key or "" + self._api_key_provider = None if organization is None: organization = os.environ.get("OPENAI_ORG_ID") @@ -612,6 +633,15 @@ def with_streaming_response(self) -> AsyncOpenAIWithStreamedResponse: def qs(self) -> Querystring: return Querystring(array_format="brackets") + async def _refresh_api_key(self) -> None: + if self._api_key_provider: + self.api_key = await self._api_key_provider() + + @override + async def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions: + await self._refresh_api_key() + return await super()._prepare_options(options) + @property @override def auth_headers(self) -> dict[str, str]: @@ -635,7 +665,7 @@ def default_headers(self) -> dict[str, str | Omit]: def copy( self, *, - api_key: str | None = None, + api_key: str | Callable[[], Awaitable[str]] | None = None, organization: str | None = None, project: str | None = None, webhook_secret: str | None = None, @@ -673,7 +703,7 @@ def copy( http_client = http_client or self._client return self.__class__( - api_key=api_key or self.api_key, + api_key=api_key or self._api_key_provider or self.api_key, organization=organization or self.organization, project=project or self.project, webhook_secret=webhook_secret or self.webhook_secret, diff --git a/src/openai/lib/azure.py b/src/openai/lib/azure.py index a994e4256c..d6143f916f 100644 --- a/src/openai/lib/azure.py +++ b/src/openai/lib/azure.py @@ -258,7 +258,7 @@ def __init__( def copy( self, *, - api_key: str | None = None, + api_key: str | Callable[[], str] | None = None, organization: str | None = None, project: str | None = None, webhook_secret: str | None = None, @@ -435,7 +435,7 @@ def __init__( azure_endpoint: str | None = None, azure_deployment: str | None = None, api_version: str | None = None, - api_key: str | None = None, + api_key: str | Callable[[], Awaitable[str]] | None = None, azure_ad_token: str | None = None, azure_ad_token_provider: AsyncAzureADTokenProvider | None = None, organization: str | None = None, @@ -539,7 +539,7 @@ def __init__( def copy( self, *, - api_key: str | None = None, + api_key: str | Callable[[], Awaitable[str]] | None = None, organization: str | None = None, project: str | None = None, webhook_secret: str | None = None, @@ -628,7 +628,7 @@ async def _configure_realtime(self, model: str, extra_query: Query) -> tuple[htt "api-version": self._api_version, "deployment": self._azure_deployment or model, } - if self.api_key != "": + if self.api_key and self.api_key != "": auth_headers = {"api-key": self.api_key} else: token = await self._get_azure_ad_token() diff --git a/src/openai/resources/beta/realtime/realtime.py b/src/openai/resources/beta/realtime/realtime.py index 7b99c7f6c4..4fa35963b6 100644 --- a/src/openai/resources/beta/realtime/realtime.py +++ b/src/openai/resources/beta/realtime/realtime.py @@ -358,6 +358,7 @@ async def __aenter__(self) -> AsyncRealtimeConnection: raise OpenAIError("You need to install `openai[realtime]` to use this method") from exc extra_query = self.__extra_query + await self.__client._refresh_api_key() auth_headers = self.__client.auth_headers if is_async_azure_client(self.__client): url, auth_headers = await self.__client._configure_realtime(self.__model, extra_query) @@ -540,6 +541,7 @@ def __enter__(self) -> RealtimeConnection: raise OpenAIError("You need to install `openai[realtime]` to use this method") from exc extra_query = self.__extra_query + self.__client._refresh_api_key() auth_headers = self.__client.auth_headers if is_azure_client(self.__client): url, auth_headers = self.__client._configure_realtime(self.__model, extra_query) diff --git a/tests/test_client.py b/tests/test_client.py index ccda50a7f0..c6435b70c6 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -11,7 +11,7 @@ import inspect import subprocess import tracemalloc -from typing import Any, Union, cast +from typing import Any, Union, Protocol, cast from textwrap import dedent from unittest import mock from typing_extensions import Literal @@ -41,6 +41,10 @@ api_key = "My API Key" +class MockRequestCall(Protocol): + request: httpx.Request + + def _get_params(client: BaseClient[Any, Any]) -> dict[str, str]: request = client._build_request(FinalRequestOptions(method="get", url="/foo")) url = httpx.URL(request.url) @@ -337,7 +341,9 @@ def test_default_headers_option(self) -> None: def test_validate_headers(self) -> None: client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True) - request = client._build_request(FinalRequestOptions(method="get", url="/foo")) + options = client._prepare_options(FinalRequestOptions(method="get", url="/foo")) + request = client._build_request(options) + assert request.headers.get("Authorization") == f"Bearer {api_key}" with pytest.raises(OpenAIError): @@ -939,6 +945,63 @@ def test_follow_redirects_disabled(self, respx_mock: MockRouter) -> None: assert exc_info.value.response.status_code == 302 assert exc_info.value.response.headers["Location"] == f"{base_url}/redirected" + def test_api_key_before_after_refresh_provider(self) -> None: + client = OpenAI(base_url=base_url, api_key=lambda: "test_bearer_token") + + assert client.api_key == "" + assert 'Authorization' not in client.auth_headers + + client._refresh_api_key() + + assert client.api_key == "test_bearer_token" + assert client.auth_headers.get("Authorization") == "Bearer test_bearer_token" + + + def test_api_key_before_after_refresh_str(self) -> None: + client = OpenAI(base_url=base_url, api_key="test_api_key") + + assert client.auth_headers.get("Authorization") == "Bearer test_api_key" + client._refresh_api_key() + + assert client.auth_headers.get("Authorization") == "Bearer test_api_key" + + @pytest.mark.respx() + def test_api_key_refresh_on_retry(self, respx_mock: MockRouter) -> None: + respx_mock.post(base_url + "/chat/completions").mock( + side_effect=[ + httpx.Response(500, json={"error": "server error"}), + httpx.Response(200, json={"foo": "bar"}), + ] + ) + + counter = 0 + + def token_provider() -> str: + nonlocal counter + + counter += 1 + + if counter == 1: + return "first" + + return "second" + + client = OpenAI(base_url=base_url, api_key=token_provider) + client.chat.completions.create(messages=[], model="gpt-4") + + calls = cast("list[MockRequestCall]", respx_mock.calls) + assert len(calls) == 2 + + assert calls[0].request.headers.get("Authorization") == "Bearer first" + assert calls[1].request.headers.get("Authorization") == "Bearer second" + + def test_copy_auth(self) -> None: + client = OpenAI(base_url=base_url, api_key=lambda: "test_bearer_token_1").copy( + api_key=lambda: "test_bearer_token_2" + ) + client._refresh_api_key() + assert client.auth_headers == {"Authorization": "Bearer test_bearer_token_2"} + class TestAsyncOpenAI: client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True) @@ -1220,9 +1283,10 @@ def test_default_headers_option(self) -> None: assert request.headers.get("x-foo") == "stainless" assert request.headers.get("x-stainless-lang") == "my-overriding-header" - def test_validate_headers(self) -> None: + async def test_validate_headers(self) -> None: client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True) - request = client._build_request(FinalRequestOptions(method="get", url="/foo")) + options = await client._prepare_options(FinalRequestOptions(method="get", url="/foo")) + request = client._build_request(options) assert request.headers.get("Authorization") == f"Bearer {api_key}" with pytest.raises(OpenAIError): @@ -1887,3 +1951,71 @@ async def test_follow_redirects_disabled(self, respx_mock: MockRouter) -> None: assert exc_info.value.response.status_code == 302 assert exc_info.value.response.headers["Location"] == f"{base_url}/redirected" + + @pytest.mark.asyncio + async def test_api_key_before_after_refresh_provider(self) -> None: + async def mock_api_key_provider(): + return "test_bearer_token" + + client = AsyncOpenAI(base_url=base_url, api_key=mock_api_key_provider) + + assert client.api_key == "" + assert 'Authorization' not in client.auth_headers + + await client._refresh_api_key() + + assert client.api_key == "test_bearer_token" + assert client.auth_headers.get("Authorization") == "Bearer test_bearer_token" + + + @pytest.mark.asyncio + async def test_api_key_before_after_refresh_str(self) -> None: + client = AsyncOpenAI(base_url=base_url, api_key="test_api_key") + + assert client.auth_headers.get("Authorization") == "Bearer test_api_key" + await client._refresh_api_key() + + assert client.auth_headers.get("Authorization") == "Bearer test_api_key" + + @pytest.mark.asyncio + @pytest.mark.respx() + async def test_bearer_token_refresh_async(self, respx_mock: MockRouter) -> None: + respx_mock.post(base_url + "/chat/completions").mock( + side_effect=[ + httpx.Response(500, json={"error": "server error"}), + httpx.Response(200, json={"foo": "bar"}), + ] + ) + + counter = 0 + + async def token_provider() -> str: + nonlocal counter + + counter += 1 + + if counter == 1: + return "first" + + return "second" + + client = AsyncOpenAI(base_url=base_url, api_key=token_provider) + await client.chat.completions.create(messages=[], model="gpt-4") + + calls = cast("list[MockRequestCall]", respx_mock.calls) + assert len(calls) == 2 + + assert calls[0].request.headers.get("Authorization") == "Bearer first" + assert calls[1].request.headers.get("Authorization") == "Bearer second" + + @pytest.mark.asyncio + async def test_copy_auth(self) -> None: + async def token_provider_1() -> str: + return "test_bearer_token_1" + + async def token_provider_2() -> str: + return "test_bearer_token_2" + + client = AsyncOpenAI(base_url=base_url, api_key=token_provider_1).copy(api_key=token_provider_2) + await client._refresh_api_key() + assert client.auth_headers == {"Authorization": "Bearer test_bearer_token_2"} diff --git a/tests/test_module_client.py b/tests/test_module_client.py index 9c9a1addab..643ea85c2e 100644 --- a/tests/test_module_client.py +++ b/tests/test_module_client.py @@ -97,6 +97,23 @@ def test_http_client_option() -> None: assert openai.completions._client._client is new_client +def test_api_key_callable() -> None: + openai.api_key = lambda: "1" + assert openai.completions._client.api_key == "1" + +def test_api_key_overridable() -> None: + openai.api_key = lambda: "1" + assert openai.completions._client.api_key == "1" + assert openai.completions._client._api_key_provider is None + + openai.api_key = "2" + assert openai.completions._client.api_key == "2" + assert openai.completions._client._api_key_provider is None + + openai.api_key = lambda: "3" + assert openai.completions._client.api_key == "3" + assert openai.completions._client._api_key_provider is None + import contextlib from typing import Iterator @@ -123,6 +140,24 @@ def test_only_api_key_results_in_openai_api() -> None: assert type(openai.completions._client).__name__ == "_ModuleClient" +def test_only_api_key_in_openai_api() -> None: + with fresh_env(): + openai.api_type = None + openai.api_key = lambda: "example bearer token" + + assert type(openai.completions._client).__name__ == "_ModuleClient" + + +def test_both_api_key_and_api_key_provider_in_openai_api() -> None: + with fresh_env(): + openai.api_key = lambda: "example bearer token" + + assert openai.api_key() == "example bearer token" + + openai.api_key = "example API key" + assert openai.api_key == "example API key" + + def test_azure_api_key_env_without_api_version() -> None: with fresh_env(): openai.api_type = None