From 1b1d99a03ab8ade62f9bea6ec63ffbfc727616eb Mon Sep 17 00:00:00 2001 From: kristapratico Date: Fri, 15 Aug 2025 00:37:58 +0000 Subject: [PATCH 1/6] token provider, support custom auth headers --- src/openai/__init__.py | 27 ++- src/openai/_client.py | 87 ++++++++-- src/openai/lib/azure.py | 8 +- .../resources/beta/realtime/realtime.py | 2 + tests/test_client.py | 160 +++++++++++++++++- tests/test_module_client.py | 44 +++++ 6 files changed, 302 insertions(+), 26 deletions(-) diff --git a/src/openai/__init__.py b/src/openai/__init__.py index 226fed9554..9b8a29cc4d 100644 --- a/src/openai/__init__.py +++ b/src/openai/__init__.py @@ -9,7 +9,18 @@ from . import types from ._types import NOT_GIVEN, Omit, NoneType, NotGiven, Transport, ProxiesTypes from ._utils import file_from_path -from ._client import Client, OpenAI, Stream, Timeout, Transport, AsyncClient, AsyncOpenAI, AsyncStream, RequestOptions +from ._client import ( + Client, + OpenAI, + Stream, + Timeout, + Transport, + AsyncClient, + AsyncOpenAI, + AsyncStream, + TokenProvider, + RequestOptions, +) from ._models import BaseModel from ._version import __title__, __version__ from ._response import APIResponse as APIResponse, AsyncAPIResponse as AsyncAPIResponse @@ -119,6 +130,8 @@ api_key: str | None = None +token_provider: TokenProvider | None = None + organization: str | None = None project: str | None = None @@ -165,6 +178,17 @@ def api_key(self, value: str | None) -> None: # type: ignore api_key = value + @property # type: ignore + @override + def token_provider(self) -> TokenProvider | None: + return token_provider + + @token_provider.setter # type: ignore + def token_provider(self, value: TokenProvider | None) -> None: # type: ignore + global token_provider + + token_provider = value + @property # type: ignore @override def organization(self) -> str | None: @@ -348,6 +372,7 @@ def _load_client() -> OpenAI: # type: ignore[reportUnusedFunction] _client = _ModuleClient( api_key=api_key, + token_provider=token_provider, organization=organization, project=project, webhook_secret=webhook_secret, diff --git a/src/openai/_client.py b/src/openai/_client.py index ed9b46f4b0..40e16f957a 100644 --- a/src/openai/_client.py +++ b/src/openai/_client.py @@ -3,7 +3,7 @@ from __future__ import annotations import os -from typing import TYPE_CHECKING, Any, Union, Mapping +from typing import TYPE_CHECKING, Any, Union, Mapping, Callable, Awaitable from typing_extensions import Self, override import httpx @@ -25,6 +25,7 @@ get_async_library, ) from ._compat import cached_property +from ._models import FinalRequestOptions from ._version import __version__ from ._streaming import Stream as Stream, AsyncStream as AsyncStream from ._exceptions import OpenAIError, APIStatusError @@ -73,6 +74,9 @@ __all__ = ["Timeout", "Transport", "ProxiesTypes", "RequestOptions", "OpenAI", "AsyncOpenAI", "Client", "AsyncClient"] +TokenProvider = Callable[[], "str | dict[str, str]"] +AsyncTokenProvider = Callable[[], Awaitable["str | dict[str, str]"]] + class OpenAI(SyncAPIClient): # client options @@ -93,6 +97,7 @@ def __init__( self, *, api_key: str | None = None, + token_provider: TokenProvider | None = None, organization: str | None = None, project: str | None = None, webhook_secret: str | None = None, @@ -124,13 +129,16 @@ def __init__( - `project` from `OPENAI_PROJECT_ID` - `webhook_secret` from `OPENAI_WEBHOOK_SECRET` """ + if api_key and token_provider: + raise ValueError("The `api_key` and `token_provider` arguments are mutually exclusive") if api_key is None: api_key = os.environ.get("OPENAI_API_KEY") - if api_key is None: + if api_key is None and token_provider is None: raise OpenAIError( - "The api_key client option must be set either by passing api_key to the client or by setting the OPENAI_API_KEY environment variable" + "The api_key or token_provider client option must be set either by passing api_key or token_provider to the client or by setting the OPENAI_API_KEY environment variable" ) - self.api_key = api_key + self.token_provider = token_provider + self.api_key = api_key or "" if organization is None: organization = os.environ.get("OPENAI_ORG_ID") @@ -163,6 +171,7 @@ def __init__( ) self._default_stream_cls = Stream + self._auth_headers: dict[str, str] = {} @cached_property def completions(self) -> Completions: @@ -279,14 +288,27 @@ def with_streaming_response(self) -> OpenAIWithStreamedResponse: def qs(self) -> Querystring: return Querystring(array_format="brackets") + def refresh_auth_headers(self) -> None: + secret = self.token_provider() if self.token_provider else self.api_key + if not secret: + # if secret is an empty string, encoding the header will fail + # so we set it to an empty dict + # this is to avoid sending an invalid Authorization header + self._auth_headers = {} + elif isinstance(secret, str): + self._auth_headers = {"Authorization": f"Bearer {secret}"} + else: + self._auth_headers = secret + + @override + def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions: + self.refresh_auth_headers() + return super()._prepare_options(options) + @property @override def auth_headers(self) -> dict[str, str]: - api_key = self.api_key - if not api_key: - # if the api key is an empty string, encoding the header will fail - return {} - return {"Authorization": f"Bearer {api_key}"} + return self._auth_headers @property @override @@ -303,6 +325,7 @@ def copy( self, *, api_key: str | None = None, + token_provider: TokenProvider | None = None, organization: str | None = None, project: str | None = None, webhook_secret: str | None = None, @@ -338,6 +361,10 @@ def copy( elif set_default_query is not None: params = set_default_query + token_provider = token_provider or self.token_provider + if token_provider is not None: + _extra_kwargs = {**_extra_kwargs, "token_provider": token_provider} + http_client = http_client or self._client return self.__class__( api_key=api_key or self.api_key, @@ -412,6 +439,7 @@ def __init__( self, *, api_key: str | None = None, + token_provider: AsyncTokenProvider | None = None, organization: str | None = None, project: str | None = None, webhook_secret: str | None = None, @@ -443,13 +471,16 @@ def __init__( - `project` from `OPENAI_PROJECT_ID` - `webhook_secret` from `OPENAI_WEBHOOK_SECRET` """ + if api_key and token_provider: + raise ValueError("The `api_key` and `token_provider` arguments are mutually exclusive") if api_key is None: api_key = os.environ.get("OPENAI_API_KEY") - if api_key is None: + if api_key is None and token_provider is None: raise OpenAIError( - "The api_key client option must be set either by passing api_key to the client or by setting the OPENAI_API_KEY environment variable" + "The api_key or token_provider client option must be set either by passing api_key or token_provider to the client or by setting the OPENAI_API_KEY environment variable" ) - self.api_key = api_key + self.token_provider = token_provider + self.api_key = api_key or "" if organization is None: organization = os.environ.get("OPENAI_ORG_ID") @@ -482,6 +513,7 @@ def __init__( ) self._default_stream_cls = AsyncStream + self._auth_headers: dict[str, str] = {} @cached_property def completions(self) -> AsyncCompletions: @@ -598,14 +630,30 @@ def with_streaming_response(self) -> AsyncOpenAIWithStreamedResponse: def qs(self) -> Querystring: return Querystring(array_format="brackets") + async def refresh_auth_headers(self) -> None: + if self.token_provider: + secret = await self.token_provider() + else: + secret = self.api_key + if not secret: + # if the secret is an empty string, encoding the header will fail + # so we set it to an empty dict + # this is to avoid sending an invalid Authorization header + self._auth_headers = {} + elif isinstance(secret, str): + self._auth_headers = {"Authorization": f"Bearer {secret}"} + else: + self._auth_headers = secret + + @override + async def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions: + await self.refresh_auth_headers() + return await super()._prepare_options(options) + @property @override def auth_headers(self) -> dict[str, str]: - api_key = self.api_key - if not api_key: - # if the api key is an empty string, encoding the header will fail - return {} - return {"Authorization": f"Bearer {api_key}"} + return self._auth_headers @property @override @@ -622,6 +670,7 @@ def copy( self, *, api_key: str | None = None, + token_provider: AsyncTokenProvider | None = None, organization: str | None = None, project: str | None = None, webhook_secret: str | None = None, @@ -657,6 +706,10 @@ def copy( elif set_default_query is not None: params = set_default_query + token_provider = token_provider or self.token_provider + if token_provider is not None: + _extra_kwargs = {**_extra_kwargs, "token_provider": token_provider} + http_client = http_client or self._client return self.__class__( api_key=api_key or self.api_key, diff --git a/src/openai/lib/azure.py b/src/openai/lib/azure.py index a994e4256c..95a3d3e9c3 100644 --- a/src/openai/lib/azure.py +++ b/src/openai/lib/azure.py @@ -255,7 +255,7 @@ def __init__( self._azure_endpoint = httpx.URL(azure_endpoint) if azure_endpoint else None @override - def copy( + def copy( # type: ignore self, *, api_key: str | None = None, @@ -301,7 +301,7 @@ def copy( }, ) - with_options = copy + with_options = copy # type: ignore def _get_azure_ad_token(self) -> str | None: if self._azure_ad_token is not None: @@ -536,7 +536,7 @@ def __init__( self._azure_endpoint = httpx.URL(azure_endpoint) if azure_endpoint else None @override - def copy( + def copy( # type: ignore self, *, api_key: str | None = None, @@ -582,7 +582,7 @@ def copy( }, ) - with_options = copy + with_options = copy # type: ignore async def _get_azure_ad_token(self) -> str | None: if self._azure_ad_token is not None: diff --git a/src/openai/resources/beta/realtime/realtime.py b/src/openai/resources/beta/realtime/realtime.py index 8e1b558cf3..beff8eb582 100644 --- a/src/openai/resources/beta/realtime/realtime.py +++ b/src/openai/resources/beta/realtime/realtime.py @@ -358,6 +358,7 @@ async def __aenter__(self) -> AsyncRealtimeConnection: raise OpenAIError("You need to install `openai[realtime]` to use this method") from exc extra_query = self.__extra_query + await self.__client.refresh_auth_headers() auth_headers = self.__client.auth_headers if is_async_azure_client(self.__client): url, auth_headers = await self.__client._configure_realtime(self.__model, extra_query) @@ -540,6 +541,7 @@ def __enter__(self) -> RealtimeConnection: raise OpenAIError("You need to install `openai[realtime]` to use this method") from exc extra_query = self.__extra_query + self.__client.refresh_auth_headers() auth_headers = self.__client.auth_headers if is_azure_client(self.__client): url, auth_headers = self.__client._configure_realtime(self.__model, extra_query) diff --git a/tests/test_client.py b/tests/test_client.py index ccda50a7f0..18b1a64b1d 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -11,7 +11,7 @@ import inspect import subprocess import tracemalloc -from typing import Any, Union, cast +from typing import Any, Union, Protocol, cast from textwrap import dedent from unittest import mock from typing_extensions import Literal @@ -41,6 +41,10 @@ api_key = "My API Key" +class MockRequestCall(Protocol): + request: httpx.Request + + def _get_params(client: BaseClient[Any, Any]) -> dict[str, str]: request = client._build_request(FinalRequestOptions(method="get", url="/foo")) url = httpx.URL(request.url) @@ -337,7 +341,9 @@ def test_default_headers_option(self) -> None: def test_validate_headers(self) -> None: client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True) - request = client._build_request(FinalRequestOptions(method="get", url="/foo")) + options = client._prepare_options(FinalRequestOptions(method="get", url="/foo")) + request = client._build_request(options) + assert request.headers.get("Authorization") == f"Bearer {api_key}" with pytest.raises(OpenAIError): @@ -939,6 +945,68 @@ def test_follow_redirects_disabled(self, respx_mock: MockRouter) -> None: assert exc_info.value.response.status_code == 302 assert exc_info.value.response.headers["Location"] == f"{base_url}/redirected" + def test_refresh_auth_headers_str_token(self) -> None: + client = OpenAI(base_url=base_url, token_provider=lambda: "test_bearer_token") + client.refresh_auth_headers() + assert client.auth_headers.get("Authorization") == "Bearer test_bearer_token" + + def test_refresh_auth_headers_dict(self) -> None: + client = OpenAI(base_url=base_url, token_provider=lambda: {"Authorization": "Bearer test_bearer_token"}) + client.refresh_auth_headers() + assert client.auth_headers == {"Authorization": "Bearer test_bearer_token"} + + def test_refresh_auth_headers_key(self) -> None: + client = OpenAI(base_url=base_url, api_key="test_api_key") + client.refresh_auth_headers() + assert client.auth_headers.get("Authorization") == "Bearer test_api_key" + + @pytest.mark.respx() + def test_token_provider_refresh(self, respx_mock: MockRouter) -> None: + respx_mock.post(base_url + "/chat/completions").mock( + side_effect=[ + httpx.Response(500, json={"error": "server error"}), + httpx.Response(200, json={"foo": "bar"}), + ] + ) + + counter = 0 + + def token_provider() -> str: + nonlocal counter + + counter += 1 + + if counter == 1: + return "first" + + return "second" + + client = OpenAI(base_url=base_url, token_provider=token_provider) + client.chat.completions.create(messages=[], model="gpt-4") + + calls = cast("list[MockRequestCall]", respx_mock.calls) + assert len(calls) == 2 + + assert calls[0].request.headers.get("Authorization") == "Bearer first" + assert calls[1].request.headers.get("Authorization") == "Bearer second" + + def test_auth_mutually_exclusive(self) -> None: + with pytest.raises(ValueError) as exc_info: + OpenAI(base_url=base_url, api_key=api_key, token_provider=lambda: "test_bearer_token") + assert str(exc_info.value) == "The `api_key` and `token_provider` arguments are mutually exclusive" + + def test_copy_auth(self) -> None: + client = OpenAI(base_url=base_url, token_provider=lambda: "test_bearer_token_1").copy( + token_provider=lambda: "test_bearer_token_2" + ) + client.refresh_auth_headers() + assert client.auth_headers == {"Authorization": "Bearer test_bearer_token_2"} + + def test_copy_auth_mutually_exclusive(self) -> None: + with pytest.raises(ValueError) as exc_info: + OpenAI(base_url=base_url, api_key=api_key).copy(token_provider=lambda: "test_bearer_token") + assert str(exc_info.value) == "The `api_key` and `token_provider` arguments are mutually exclusive" + class TestAsyncOpenAI: client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True) @@ -1220,9 +1288,10 @@ def test_default_headers_option(self) -> None: assert request.headers.get("x-foo") == "stainless" assert request.headers.get("x-stainless-lang") == "my-overriding-header" - def test_validate_headers(self) -> None: + async def test_validate_headers(self) -> None: client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True) - request = client._build_request(FinalRequestOptions(method="get", url="/foo")) + options = await client._prepare_options(FinalRequestOptions(method="get", url="/foo")) + request = client._build_request(options) assert request.headers.get("Authorization") == f"Bearer {api_key}" with pytest.raises(OpenAIError): @@ -1887,3 +1956,86 @@ async def test_follow_redirects_disabled(self, respx_mock: MockRouter) -> None: assert exc_info.value.response.status_code == 302 assert exc_info.value.response.headers["Location"] == f"{base_url}/redirected" + + @pytest.mark.asyncio + async def test_refresh_auth_headers_str_token_async(self) -> None: + async def token_provider() -> str: + return "test_bearer_token" + + client = AsyncOpenAI(base_url=base_url, token_provider=token_provider) + await client.refresh_auth_headers() + assert client.auth_headers.get("Authorization") == "Bearer test_bearer_token" + + @pytest.mark.asyncio + async def test_refresh_auth_headers_dict_async(self) -> None: + async def token_provider() -> dict[str, str]: + return {"Authorization": "Bearer test_bearer_token"} + + client = AsyncOpenAI(base_url=base_url, token_provider=token_provider) + await client.refresh_auth_headers() + assert client.auth_headers == {"Authorization": "Bearer test_bearer_token"} + + @pytest.mark.asyncio + async def test_refresh_auth_headers_key_async(self) -> None: + client = AsyncOpenAI(base_url=base_url, api_key="test_api_key") + await client.refresh_auth_headers() + assert client.auth_headers.get("Authorization") == "Bearer test_api_key" + + @pytest.mark.asyncio + @pytest.mark.respx() + async def test_bearer_token_refresh_async(self, respx_mock: MockRouter) -> None: + respx_mock.post(base_url + "/chat/completions").mock( + side_effect=[ + httpx.Response(500, json={"error": "server error"}), + httpx.Response(200, json={"foo": "bar"}), + ] + ) + + counter = 0 + + async def token_provider() -> str: + nonlocal counter + + counter += 1 + + if counter == 1: + return "first" + + return "second" + + client = AsyncOpenAI(base_url=base_url, token_provider=token_provider) + await client.chat.completions.create(messages=[], model="gpt-4") + + calls = cast("list[MockRequestCall]", respx_mock.calls) + assert len(calls) == 2 + + assert calls[0].request.headers.get("Authorization") == "Bearer first" + assert calls[1].request.headers.get("Authorization") == "Bearer second" + + def test_auth_mutually_exclusive_async(self) -> None: + async def token_provider() -> str: + return "test_bearer_token" + + with pytest.raises(ValueError) as exc_info: + AsyncOpenAI(base_url=base_url, api_key=api_key, token_provider=token_provider) + assert str(exc_info.value) == "The `api_key` and `token_provider` arguments are mutually exclusive" + + @pytest.mark.asyncio + async def test_copy_auth(self) -> None: + async def token_provider_1() -> str: + return "test_bearer_token_1" + + async def token_provider_2() -> str: + return "test_bearer_token_2" + + client = AsyncOpenAI(base_url=base_url, token_provider=token_provider_1).copy(token_provider=token_provider_2) + await client.refresh_auth_headers() + assert client.auth_headers == {"Authorization": "Bearer test_bearer_token_2"} + + def test_copy_auth_mutually_exclusive_async(self) -> None: + async def token_provider() -> str: + return "test_bearer_token" + + with pytest.raises(ValueError) as exc_info: + AsyncOpenAI(base_url=base_url, api_key=api_key).copy(token_provider=token_provider) + assert str(exc_info.value) == "The `api_key` and `token_provider` arguments are mutually exclusive" diff --git a/tests/test_module_client.py b/tests/test_module_client.py index 9c9a1addab..4ba2143e05 100644 --- a/tests/test_module_client.py +++ b/tests/test_module_client.py @@ -15,6 +15,7 @@ def reset_state() -> None: openai._reset_client() openai.api_key = None or "My API Key" + openai.token_provider = None openai.organization = None openai.project = None openai.webhook_secret = None @@ -97,6 +98,28 @@ def test_http_client_option() -> None: assert openai.completions._client._client is new_client +def test_token_provider_str_option() -> None: + assert openai.token_provider is None + assert openai.completions._client.token_provider is None + + openai.token_provider = lambda: "foo" + + assert openai.token_provider() == "foo" + assert openai.completions._client.token_provider + assert openai.completions._client.token_provider() == "foo" + + +def test_token_provider_dict_option() -> None: + assert openai.token_provider is None + assert openai.completions._client.token_provider is None + + openai.token_provider = lambda: {"foo": "bar"} + + assert openai.token_provider() == {"foo": "bar"} + assert openai.completions._client.token_provider + assert openai.completions._client.token_provider() == {"foo": "bar"} + + import contextlib from typing import Iterator @@ -123,6 +146,27 @@ def test_only_api_key_results_in_openai_api() -> None: assert type(openai.completions._client).__name__ == "_ModuleClient" +def test_only_token_provider_in_openai_api() -> None: + with fresh_env(): + openai.api_type = None + openai.api_key = None + openai.token_provider = lambda: "example bearer token" + + assert type(openai.completions._client).__name__ == "_ModuleClient" + + +def test_both_api_key_and_token_provider_in_openai_api() -> None: + with fresh_env(): + openai.api_key = "example API key" + openai.token_provider = lambda: "example bearer token" + + with pytest.raises( + ValueError, + match=r"The `api_key` and `token_provider` arguments are mutually exclusive", + ): + openai.completions._client # noqa: B018 + + def test_azure_api_key_env_without_api_version() -> None: with fresh_env(): openai.api_type = None From 12a6b3226eef9b79257121f6d68f7b68609004eb Mon Sep 17 00:00:00 2001 From: kristapratico Date: Fri, 15 Aug 2025 01:03:34 +0000 Subject: [PATCH 2/6] rename to auth_provider --- src/openai/__init__.py | 18 ++++++------- src/openai/_client.py | 50 +++++++++++++++++------------------ tests/test_client.py | 52 ++++++++++++++++++------------------- tests/test_module_client.py | 40 ++++++++++++++-------------- 4 files changed, 80 insertions(+), 80 deletions(-) diff --git a/src/openai/__init__.py b/src/openai/__init__.py index 9b8a29cc4d..e5e623464c 100644 --- a/src/openai/__init__.py +++ b/src/openai/__init__.py @@ -18,7 +18,7 @@ AsyncClient, AsyncOpenAI, AsyncStream, - TokenProvider, + AuthProvider, RequestOptions, ) from ._models import BaseModel @@ -130,7 +130,7 @@ api_key: str | None = None -token_provider: TokenProvider | None = None +auth_provider: AuthProvider | None = None organization: str | None = None @@ -180,14 +180,14 @@ def api_key(self, value: str | None) -> None: # type: ignore @property # type: ignore @override - def token_provider(self) -> TokenProvider | None: - return token_provider + def auth_provider(self) -> AuthProvider | None: + return auth_provider - @token_provider.setter # type: ignore - def token_provider(self, value: TokenProvider | None) -> None: # type: ignore - global token_provider + @auth_provider.setter # type: ignore + def auth_provider(self, value: AuthProvider | None) -> None: # type: ignore + global auth_provider - token_provider = value + auth_provider = value @property # type: ignore @override @@ -372,7 +372,7 @@ def _load_client() -> OpenAI: # type: ignore[reportUnusedFunction] _client = _ModuleClient( api_key=api_key, - token_provider=token_provider, + auth_provider=auth_provider, organization=organization, project=project, webhook_secret=webhook_secret, diff --git a/src/openai/_client.py b/src/openai/_client.py index 40e16f957a..dc679feb49 100644 --- a/src/openai/_client.py +++ b/src/openai/_client.py @@ -74,8 +74,8 @@ __all__ = ["Timeout", "Transport", "ProxiesTypes", "RequestOptions", "OpenAI", "AsyncOpenAI", "Client", "AsyncClient"] -TokenProvider = Callable[[], "str | dict[str, str]"] -AsyncTokenProvider = Callable[[], Awaitable["str | dict[str, str]"]] +AuthProvider = Callable[[], "str | dict[str, str]"] +AsyncAuthProvider = Callable[[], Awaitable["str | dict[str, str]"]] class OpenAI(SyncAPIClient): @@ -97,7 +97,7 @@ def __init__( self, *, api_key: str | None = None, - token_provider: TokenProvider | None = None, + auth_provider: AuthProvider | None = None, organization: str | None = None, project: str | None = None, webhook_secret: str | None = None, @@ -129,15 +129,15 @@ def __init__( - `project` from `OPENAI_PROJECT_ID` - `webhook_secret` from `OPENAI_WEBHOOK_SECRET` """ - if api_key and token_provider: - raise ValueError("The `api_key` and `token_provider` arguments are mutually exclusive") + if api_key and auth_provider: + raise ValueError("The `api_key` and `auth_provider` arguments are mutually exclusive") if api_key is None: api_key = os.environ.get("OPENAI_API_KEY") - if api_key is None and token_provider is None: + if api_key is None and auth_provider is None: raise OpenAIError( - "The api_key or token_provider client option must be set either by passing api_key or token_provider to the client or by setting the OPENAI_API_KEY environment variable" + "The api_key or auth_provider client option must be set either by passing api_key or auth_provider to the client or by setting the OPENAI_API_KEY environment variable" ) - self.token_provider = token_provider + self.auth_provider = auth_provider self.api_key = api_key or "" if organization is None: @@ -289,7 +289,7 @@ def qs(self) -> Querystring: return Querystring(array_format="brackets") def refresh_auth_headers(self) -> None: - secret = self.token_provider() if self.token_provider else self.api_key + secret = self.auth_provider() if self.auth_provider else self.api_key if not secret: # if secret is an empty string, encoding the header will fail # so we set it to an empty dict @@ -325,7 +325,7 @@ def copy( self, *, api_key: str | None = None, - token_provider: TokenProvider | None = None, + auth_provider: AuthProvider | None = None, organization: str | None = None, project: str | None = None, webhook_secret: str | None = None, @@ -361,9 +361,9 @@ def copy( elif set_default_query is not None: params = set_default_query - token_provider = token_provider or self.token_provider - if token_provider is not None: - _extra_kwargs = {**_extra_kwargs, "token_provider": token_provider} + auth_provider = auth_provider or self.auth_provider + if auth_provider is not None: + _extra_kwargs = {**_extra_kwargs, "auth_provider": auth_provider} http_client = http_client or self._client return self.__class__( @@ -439,7 +439,7 @@ def __init__( self, *, api_key: str | None = None, - token_provider: AsyncTokenProvider | None = None, + auth_provider: AsyncAuthProvider | None = None, organization: str | None = None, project: str | None = None, webhook_secret: str | None = None, @@ -471,15 +471,15 @@ def __init__( - `project` from `OPENAI_PROJECT_ID` - `webhook_secret` from `OPENAI_WEBHOOK_SECRET` """ - if api_key and token_provider: - raise ValueError("The `api_key` and `token_provider` arguments are mutually exclusive") + if api_key and auth_provider: + raise ValueError("The `api_key` and `auth_provider` arguments are mutually exclusive") if api_key is None: api_key = os.environ.get("OPENAI_API_KEY") - if api_key is None and token_provider is None: + if api_key is None and auth_provider is None: raise OpenAIError( - "The api_key or token_provider client option must be set either by passing api_key or token_provider to the client or by setting the OPENAI_API_KEY environment variable" + "The api_key or auth_provider client option must be set either by passing api_key or auth_provider to the client or by setting the OPENAI_API_KEY environment variable" ) - self.token_provider = token_provider + self.auth_provider = auth_provider self.api_key = api_key or "" if organization is None: @@ -631,8 +631,8 @@ def qs(self) -> Querystring: return Querystring(array_format="brackets") async def refresh_auth_headers(self) -> None: - if self.token_provider: - secret = await self.token_provider() + if self.auth_provider: + secret = await self.auth_provider() else: secret = self.api_key if not secret: @@ -670,7 +670,7 @@ def copy( self, *, api_key: str | None = None, - token_provider: AsyncTokenProvider | None = None, + auth_provider: AsyncAuthProvider | None = None, organization: str | None = None, project: str | None = None, webhook_secret: str | None = None, @@ -706,9 +706,9 @@ def copy( elif set_default_query is not None: params = set_default_query - token_provider = token_provider or self.token_provider - if token_provider is not None: - _extra_kwargs = {**_extra_kwargs, "token_provider": token_provider} + auth_provider = auth_provider or self.auth_provider + if auth_provider is not None: + _extra_kwargs = {**_extra_kwargs, "auth_provider": auth_provider} http_client = http_client or self._client return self.__class__( diff --git a/tests/test_client.py b/tests/test_client.py index 18b1a64b1d..91ad7e06ba 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -946,12 +946,12 @@ def test_follow_redirects_disabled(self, respx_mock: MockRouter) -> None: assert exc_info.value.response.headers["Location"] == f"{base_url}/redirected" def test_refresh_auth_headers_str_token(self) -> None: - client = OpenAI(base_url=base_url, token_provider=lambda: "test_bearer_token") + client = OpenAI(base_url=base_url, auth_provider=lambda: "test_bearer_token") client.refresh_auth_headers() assert client.auth_headers.get("Authorization") == "Bearer test_bearer_token" def test_refresh_auth_headers_dict(self) -> None: - client = OpenAI(base_url=base_url, token_provider=lambda: {"Authorization": "Bearer test_bearer_token"}) + client = OpenAI(base_url=base_url, auth_provider=lambda: {"Authorization": "Bearer test_bearer_token"}) client.refresh_auth_headers() assert client.auth_headers == {"Authorization": "Bearer test_bearer_token"} @@ -961,7 +961,7 @@ def test_refresh_auth_headers_key(self) -> None: assert client.auth_headers.get("Authorization") == "Bearer test_api_key" @pytest.mark.respx() - def test_token_provider_refresh(self, respx_mock: MockRouter) -> None: + def test_auth_provider_refresh(self, respx_mock: MockRouter) -> None: respx_mock.post(base_url + "/chat/completions").mock( side_effect=[ httpx.Response(500, json={"error": "server error"}), @@ -971,7 +971,7 @@ def test_token_provider_refresh(self, respx_mock: MockRouter) -> None: counter = 0 - def token_provider() -> str: + def auth_provider() -> str: nonlocal counter counter += 1 @@ -981,7 +981,7 @@ def token_provider() -> str: return "second" - client = OpenAI(base_url=base_url, token_provider=token_provider) + client = OpenAI(base_url=base_url, auth_provider=auth_provider) client.chat.completions.create(messages=[], model="gpt-4") calls = cast("list[MockRequestCall]", respx_mock.calls) @@ -992,20 +992,20 @@ def token_provider() -> str: def test_auth_mutually_exclusive(self) -> None: with pytest.raises(ValueError) as exc_info: - OpenAI(base_url=base_url, api_key=api_key, token_provider=lambda: "test_bearer_token") - assert str(exc_info.value) == "The `api_key` and `token_provider` arguments are mutually exclusive" + OpenAI(base_url=base_url, api_key=api_key, auth_provider=lambda: "test_bearer_token") + assert str(exc_info.value) == "The `api_key` and `auth_provider` arguments are mutually exclusive" def test_copy_auth(self) -> None: - client = OpenAI(base_url=base_url, token_provider=lambda: "test_bearer_token_1").copy( - token_provider=lambda: "test_bearer_token_2" + client = OpenAI(base_url=base_url, auth_provider=lambda: "test_bearer_token_1").copy( + auth_provider=lambda: "test_bearer_token_2" ) client.refresh_auth_headers() assert client.auth_headers == {"Authorization": "Bearer test_bearer_token_2"} def test_copy_auth_mutually_exclusive(self) -> None: with pytest.raises(ValueError) as exc_info: - OpenAI(base_url=base_url, api_key=api_key).copy(token_provider=lambda: "test_bearer_token") - assert str(exc_info.value) == "The `api_key` and `token_provider` arguments are mutually exclusive" + OpenAI(base_url=base_url, api_key=api_key).copy(auth_provider=lambda: "test_bearer_token") + assert str(exc_info.value) == "The `api_key` and `auth_provider` arguments are mutually exclusive" class TestAsyncOpenAI: @@ -1959,19 +1959,19 @@ async def test_follow_redirects_disabled(self, respx_mock: MockRouter) -> None: @pytest.mark.asyncio async def test_refresh_auth_headers_str_token_async(self) -> None: - async def token_provider() -> str: + async def auth_provider() -> str: return "test_bearer_token" - client = AsyncOpenAI(base_url=base_url, token_provider=token_provider) + client = AsyncOpenAI(base_url=base_url, auth_provider=auth_provider) await client.refresh_auth_headers() assert client.auth_headers.get("Authorization") == "Bearer test_bearer_token" @pytest.mark.asyncio async def test_refresh_auth_headers_dict_async(self) -> None: - async def token_provider() -> dict[str, str]: + async def auth_provider() -> dict[str, str]: return {"Authorization": "Bearer test_bearer_token"} - client = AsyncOpenAI(base_url=base_url, token_provider=token_provider) + client = AsyncOpenAI(base_url=base_url, auth_provider=auth_provider) await client.refresh_auth_headers() assert client.auth_headers == {"Authorization": "Bearer test_bearer_token"} @@ -1993,7 +1993,7 @@ async def test_bearer_token_refresh_async(self, respx_mock: MockRouter) -> None: counter = 0 - async def token_provider() -> str: + async def auth_provider() -> str: nonlocal counter counter += 1 @@ -2003,7 +2003,7 @@ async def token_provider() -> str: return "second" - client = AsyncOpenAI(base_url=base_url, token_provider=token_provider) + client = AsyncOpenAI(base_url=base_url, auth_provider=auth_provider) await client.chat.completions.create(messages=[], model="gpt-4") calls = cast("list[MockRequestCall]", respx_mock.calls) @@ -2013,29 +2013,29 @@ async def token_provider() -> str: assert calls[1].request.headers.get("Authorization") == "Bearer second" def test_auth_mutually_exclusive_async(self) -> None: - async def token_provider() -> str: + async def auth_provider() -> str: return "test_bearer_token" with pytest.raises(ValueError) as exc_info: - AsyncOpenAI(base_url=base_url, api_key=api_key, token_provider=token_provider) - assert str(exc_info.value) == "The `api_key` and `token_provider` arguments are mutually exclusive" + AsyncOpenAI(base_url=base_url, api_key=api_key, auth_provider=auth_provider) + assert str(exc_info.value) == "The `api_key` and `auth_provider` arguments are mutually exclusive" @pytest.mark.asyncio async def test_copy_auth(self) -> None: - async def token_provider_1() -> str: + async def auth_provider_1() -> str: return "test_bearer_token_1" - async def token_provider_2() -> str: + async def auth_provider_2() -> str: return "test_bearer_token_2" - client = AsyncOpenAI(base_url=base_url, token_provider=token_provider_1).copy(token_provider=token_provider_2) + client = AsyncOpenAI(base_url=base_url, auth_provider=auth_provider_1).copy(auth_provider=auth_provider_2) await client.refresh_auth_headers() assert client.auth_headers == {"Authorization": "Bearer test_bearer_token_2"} def test_copy_auth_mutually_exclusive_async(self) -> None: - async def token_provider() -> str: + async def auth_provider() -> str: return "test_bearer_token" with pytest.raises(ValueError) as exc_info: - AsyncOpenAI(base_url=base_url, api_key=api_key).copy(token_provider=token_provider) - assert str(exc_info.value) == "The `api_key` and `token_provider` arguments are mutually exclusive" + AsyncOpenAI(base_url=base_url, api_key=api_key).copy(auth_provider=auth_provider) + assert str(exc_info.value) == "The `api_key` and `auth_provider` arguments are mutually exclusive" diff --git a/tests/test_module_client.py b/tests/test_module_client.py index 4ba2143e05..862fe713db 100644 --- a/tests/test_module_client.py +++ b/tests/test_module_client.py @@ -15,7 +15,7 @@ def reset_state() -> None: openai._reset_client() openai.api_key = None or "My API Key" - openai.token_provider = None + openai.auth_provider = None openai.organization = None openai.project = None openai.webhook_secret = None @@ -98,26 +98,26 @@ def test_http_client_option() -> None: assert openai.completions._client._client is new_client -def test_token_provider_str_option() -> None: - assert openai.token_provider is None - assert openai.completions._client.token_provider is None +def test_auth_provider_str_option() -> None: + assert openai.auth_provider is None + assert openai.completions._client.auth_provider is None - openai.token_provider = lambda: "foo" + openai.auth_provider = lambda: "foo" - assert openai.token_provider() == "foo" - assert openai.completions._client.token_provider - assert openai.completions._client.token_provider() == "foo" + assert openai.auth_provider() == "foo" + assert openai.completions._client.auth_provider + assert openai.completions._client.auth_provider() == "foo" -def test_token_provider_dict_option() -> None: - assert openai.token_provider is None - assert openai.completions._client.token_provider is None +def test_auth_provider_dict_option() -> None: + assert openai.auth_provider is None + assert openai.completions._client.auth_provider is None - openai.token_provider = lambda: {"foo": "bar"} + openai.auth_provider = lambda: {"foo": "bar"} - assert openai.token_provider() == {"foo": "bar"} - assert openai.completions._client.token_provider - assert openai.completions._client.token_provider() == {"foo": "bar"} + assert openai.auth_provider() == {"foo": "bar"} + assert openai.completions._client.auth_provider + assert openai.completions._client.auth_provider() == {"foo": "bar"} import contextlib @@ -146,23 +146,23 @@ def test_only_api_key_results_in_openai_api() -> None: assert type(openai.completions._client).__name__ == "_ModuleClient" -def test_only_token_provider_in_openai_api() -> None: +def test_only_auth_provider_in_openai_api() -> None: with fresh_env(): openai.api_type = None openai.api_key = None - openai.token_provider = lambda: "example bearer token" + openai.auth_provider = lambda: "example bearer token" assert type(openai.completions._client).__name__ == "_ModuleClient" -def test_both_api_key_and_token_provider_in_openai_api() -> None: +def test_both_api_key_and_auth_provider_in_openai_api() -> None: with fresh_env(): openai.api_key = "example API key" - openai.token_provider = lambda: "example bearer token" + openai.auth_provider = lambda: "example bearer token" with pytest.raises( ValueError, - match=r"The `api_key` and `token_provider` arguments are mutually exclusive", + match=r"The `api_key` and `auth_provider` arguments are mutually exclusive", ): openai.completions._client # noqa: B018 From e8591e242b8da731f788d0e27bb5385d4a091eef Mon Sep 17 00:00:00 2001 From: Johan Stenberg Date: Tue, 19 Aug 2025 12:10:24 -0700 Subject: [PATCH 3/6] Change auth handler to operate on FinalRequestOptions --- src/openai/_client.py | 53 ++++++-------------- tests/test_client.py | 114 ++++++++++++++++++++---------------------- 2 files changed, 71 insertions(+), 96 deletions(-) diff --git a/src/openai/_client.py b/src/openai/_client.py index dc679feb49..29e623e963 100644 --- a/src/openai/_client.py +++ b/src/openai/_client.py @@ -74,8 +74,8 @@ __all__ = ["Timeout", "Transport", "ProxiesTypes", "RequestOptions", "OpenAI", "AsyncOpenAI", "Client", "AsyncClient"] -AuthProvider = Callable[[], "str | dict[str, str]"] -AsyncAuthProvider = Callable[[], Awaitable["str | dict[str, str]"]] +AuthProvider = Callable[[FinalRequestOptions], FinalRequestOptions] +AsyncAuthProvider = Callable[[FinalRequestOptions], Awaitable[FinalRequestOptions]] class OpenAI(SyncAPIClient): @@ -171,7 +171,6 @@ def __init__( ) self._default_stream_cls = Stream - self._auth_headers: dict[str, str] = {} @cached_property def completions(self) -> Completions: @@ -288,27 +287,19 @@ def with_streaming_response(self) -> OpenAIWithStreamedResponse: def qs(self) -> Querystring: return Querystring(array_format="brackets") - def refresh_auth_headers(self) -> None: - secret = self.auth_provider() if self.auth_provider else self.api_key - if not secret: - # if secret is an empty string, encoding the header will fail - # so we set it to an empty dict - # this is to avoid sending an invalid Authorization header - self._auth_headers = {} - elif isinstance(secret, str): - self._auth_headers = {"Authorization": f"Bearer {secret}"} - else: - self._auth_headers = secret - @override def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions: - self.refresh_auth_headers() - return super()._prepare_options(options) + options = super()._prepare_options(options) + return self.auth_provider(options) if self.auth_provider else options @property @override def auth_headers(self) -> dict[str, str]: - return self._auth_headers + if self.api_key: + return { + "Authorization": f"Bearer {self.api_key}" + } + return {} @property @override @@ -513,7 +504,6 @@ def __init__( ) self._default_stream_cls = AsyncStream - self._auth_headers: dict[str, str] = {} @cached_property def completions(self) -> AsyncCompletions: @@ -630,30 +620,19 @@ def with_streaming_response(self) -> AsyncOpenAIWithStreamedResponse: def qs(self) -> Querystring: return Querystring(array_format="brackets") - async def refresh_auth_headers(self) -> None: - if self.auth_provider: - secret = await self.auth_provider() - else: - secret = self.api_key - if not secret: - # if the secret is an empty string, encoding the header will fail - # so we set it to an empty dict - # this is to avoid sending an invalid Authorization header - self._auth_headers = {} - elif isinstance(secret, str): - self._auth_headers = {"Authorization": f"Bearer {secret}"} - else: - self._auth_headers = secret - @override async def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions: - await self.refresh_auth_headers() - return await super()._prepare_options(options) + options = await super()._prepare_options(options) + return await self.auth_provider(options) if self.auth_provider else options @property @override def auth_headers(self) -> dict[str, str]: - return self._auth_headers + if self.api_key: + return { + "Authorization": f"Bearer {self.api_key}" + } + return {} @property @override diff --git a/tests/test_client.py b/tests/test_client.py index 91ad7e06ba..c5fb7f2118 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -11,7 +11,7 @@ import inspect import subprocess import tracemalloc -from typing import Any, Union, Protocol, cast +from typing import Any, Union, Protocol, cast, Callable from textwrap import dedent from unittest import mock from typing_extensions import Literal @@ -24,7 +24,9 @@ from openai import OpenAI, AsyncOpenAI, APIResponseValidationError from openai._types import Omit from openai._models import BaseModel, FinalRequestOptions +from openai._compat import model_copy from openai._streaming import Stream, AsyncStream +from openai._utils import is_given from openai._exceptions import OpenAIError, APIStatusError, APITimeoutError, APIResponseValidationError from openai._base_client import ( DEFAULT_TIMEOUT, @@ -44,6 +46,27 @@ class MockRequestCall(Protocol): request: httpx.Request +def mock_auth_provider(token: str = 'dummy', *, additional: dict[str, str] = {}) -> Callable[[FinalRequestOptions], FinalRequestOptions]: + """ + Mock auth provider that returns a FinalRequestOptions with the Authorization header set. + """ + def auth_provider(options: FinalRequestOptions) -> FinalRequestOptions: + """ + Mock auth provider that returns a FinalRequestOptions with the Authorization header set. + """ + updated = model_copy(options) + headers = { **updated.headers } if is_given(updated.headers) else {} + headers['Authorization'] = f"Bearer {token}" + updated.headers = { **additional, **headers } + return updated + + return auth_provider + +def async_mock_auth_provider(token: str = 'dummy', *, additional: dict[str, str] = {}): + async def auth_provider(options: FinalRequestOptions) -> FinalRequestOptions: + sync_auth_provider = mock_auth_provider(token, additional=additional) + return sync_auth_provider(options) + return auth_provider def _get_params(client: BaseClient[Any, Any]) -> dict[str, str]: request = client._build_request(FinalRequestOptions(method="get", url="/foo")) @@ -945,21 +968,6 @@ def test_follow_redirects_disabled(self, respx_mock: MockRouter) -> None: assert exc_info.value.response.status_code == 302 assert exc_info.value.response.headers["Location"] == f"{base_url}/redirected" - def test_refresh_auth_headers_str_token(self) -> None: - client = OpenAI(base_url=base_url, auth_provider=lambda: "test_bearer_token") - client.refresh_auth_headers() - assert client.auth_headers.get("Authorization") == "Bearer test_bearer_token" - - def test_refresh_auth_headers_dict(self) -> None: - client = OpenAI(base_url=base_url, auth_provider=lambda: {"Authorization": "Bearer test_bearer_token"}) - client.refresh_auth_headers() - assert client.auth_headers == {"Authorization": "Bearer test_bearer_token"} - - def test_refresh_auth_headers_key(self) -> None: - client = OpenAI(base_url=base_url, api_key="test_api_key") - client.refresh_auth_headers() - assert client.auth_headers.get("Authorization") == "Bearer test_api_key" - @pytest.mark.respx() def test_auth_provider_refresh(self, respx_mock: MockRouter) -> None: respx_mock.post(base_url + "/chat/completions").mock( @@ -971,15 +979,19 @@ def test_auth_provider_refresh(self, respx_mock: MockRouter) -> None: counter = 0 - def auth_provider() -> str: + def auth_provider(options: FinalRequestOptions) -> FinalRequestOptions: nonlocal counter counter += 1 + updated = model_copy(options) + updated_headers = { **options.headers} if is_given(options.headers) else {} if counter == 1: - return "first" - - return "second" + updated_headers["Authorization"] = "Bearer first" + else: + updated_headers["Authorization"] = "Bearer second" + updated.headers = updated_headers + return updated client = OpenAI(base_url=base_url, auth_provider=auth_provider) client.chat.completions.create(messages=[], model="gpt-4") @@ -992,19 +1004,19 @@ def auth_provider() -> str: def test_auth_mutually_exclusive(self) -> None: with pytest.raises(ValueError) as exc_info: - OpenAI(base_url=base_url, api_key=api_key, auth_provider=lambda: "test_bearer_token") + OpenAI(base_url=base_url, api_key=api_key, auth_provider=mock_auth_provider()) assert str(exc_info.value) == "The `api_key` and `auth_provider` arguments are mutually exclusive" def test_copy_auth(self) -> None: - client = OpenAI(base_url=base_url, auth_provider=lambda: "test_bearer_token_1").copy( - auth_provider=lambda: "test_bearer_token_2" + client = OpenAI(base_url=base_url, auth_provider=mock_auth_provider("Bearer test_bearer_token_1")).copy( + auth_provider=mock_auth_provider("test_bearer_token_2") ) - client.refresh_auth_headers() - assert client.auth_headers == {"Authorization": "Bearer test_bearer_token_2"} + options = client._prepare_options(FinalRequestOptions(method="get", url="/foo")) + assert options.headers == {"Authorization": "Bearer test_bearer_token_2"} def test_copy_auth_mutually_exclusive(self) -> None: with pytest.raises(ValueError) as exc_info: - OpenAI(base_url=base_url, api_key=api_key).copy(auth_provider=lambda: "test_bearer_token") + OpenAI(base_url=base_url, api_key=api_key).copy(auth_provider=mock_auth_provider()) assert str(exc_info.value) == "The `api_key` and `auth_provider` arguments are mutually exclusive" @@ -1957,29 +1969,12 @@ async def test_follow_redirects_disabled(self, respx_mock: MockRouter) -> None: assert exc_info.value.response.status_code == 302 assert exc_info.value.response.headers["Location"] == f"{base_url}/redirected" - @pytest.mark.asyncio - async def test_refresh_auth_headers_str_token_async(self) -> None: - async def auth_provider() -> str: - return "test_bearer_token" - - client = AsyncOpenAI(base_url=base_url, auth_provider=auth_provider) - await client.refresh_auth_headers() - assert client.auth_headers.get("Authorization") == "Bearer test_bearer_token" - @pytest.mark.asyncio async def test_refresh_auth_headers_dict_async(self) -> None: - async def auth_provider() -> dict[str, str]: - return {"Authorization": "Bearer test_bearer_token"} + client = AsyncOpenAI(base_url=base_url, auth_provider=async_mock_auth_provider('test_bearer_token', additional = { 'Second': 'Value'})) + options = await client._prepare_options(FinalRequestOptions(method='GET', url='/foo')) + assert options.headers == { "Authorization": "Bearer test_bearer_token", "Second": "Value"} - client = AsyncOpenAI(base_url=base_url, auth_provider=auth_provider) - await client.refresh_auth_headers() - assert client.auth_headers == {"Authorization": "Bearer test_bearer_token"} - - @pytest.mark.asyncio - async def test_refresh_auth_headers_key_async(self) -> None: - client = AsyncOpenAI(base_url=base_url, api_key="test_api_key") - await client.refresh_auth_headers() - assert client.auth_headers.get("Authorization") == "Bearer test_api_key" @pytest.mark.asyncio @pytest.mark.respx() @@ -1993,15 +1988,19 @@ async def test_bearer_token_refresh_async(self, respx_mock: MockRouter) -> None: counter = 0 - async def auth_provider() -> str: + async def auth_provider(options: FinalRequestOptions) -> FinalRequestOptions: nonlocal counter counter += 1 + updated_headers = { **options.headers} if is_given(options.headers) else {} + updated = model_copy(options) if counter == 1: - return "first" - - return "second" + updated_headers["Authorization"] = "Bearer first" + else: + updated_headers["Authorization"] = "Bearer second" + updated.headers = updated_headers + return updated client = AsyncOpenAI(base_url=base_url, auth_provider=auth_provider) await client.chat.completions.create(messages=[], model="gpt-4") @@ -2013,8 +2012,8 @@ async def auth_provider() -> str: assert calls[1].request.headers.get("Authorization") == "Bearer second" def test_auth_mutually_exclusive_async(self) -> None: - async def auth_provider() -> str: - return "test_bearer_token" + async def auth_provider(options: FinalRequestOptions) -> FinalRequestOptions: + return options with pytest.raises(ValueError) as exc_info: AsyncOpenAI(base_url=base_url, api_key=api_key, auth_provider=auth_provider) @@ -2022,15 +2021,12 @@ async def auth_provider() -> str: @pytest.mark.asyncio async def test_copy_auth(self) -> None: - async def auth_provider_1() -> str: - return "test_bearer_token_1" - - async def auth_provider_2() -> str: - return "test_bearer_token_2" + auth_provider_1 = async_mock_auth_provider('First') + auth_provider_2 = async_mock_auth_provider('Second') client = AsyncOpenAI(base_url=base_url, auth_provider=auth_provider_1).copy(auth_provider=auth_provider_2) - await client.refresh_auth_headers() - assert client.auth_headers == {"Authorization": "Bearer test_bearer_token_2"} + options = await client._prepare_options(FinalRequestOptions(method='GET', url='/foo')) + assert options.headers.get("Authorization") == "Bearer Second" def test_copy_auth_mutually_exclusive_async(self) -> None: async def auth_provider() -> str: From 7a34ac2091d8c0db60ba29b471927b365100339d Mon Sep 17 00:00:00 2001 From: Johan Stenberg Date: Tue, 19 Aug 2025 18:56:01 -0700 Subject: [PATCH 4/6] Add AuthProvider class with support for url rewrite, arbitrary auth headers, query parameters and challenge auth --- examples/azure_ad.py | 36 ++++--- src/openai/__init__.py | 5 +- src/openai/_client.py | 46 ++++++++- src/openai/lib/azure.py | 43 +++++++- .../resources/beta/realtime/realtime.py | 36 ++++--- tests/test_client.py | 97 +++++++++---------- 6 files changed, 175 insertions(+), 88 deletions(-) diff --git a/examples/azure_ad.py b/examples/azure_ad.py index 67e2f23713..a43009aba4 100755 --- a/examples/azure_ad.py +++ b/examples/azure_ad.py @@ -1,28 +1,24 @@ import asyncio -from openai.lib.azure import AzureOpenAI, AsyncAzureOpenAI, AzureADTokenProvider, AsyncAzureADTokenProvider - -scopes = "https://cognitiveservices.azure.com/.default" - -# May change in the future -# https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#rest-api-versioning -api_version = "2023-07-01-preview" +from openai.lib.azure import OpenAI, AsyncOpenAI, AzureAuth, AsyncAzureAuth, AzureADTokenProvider, AsyncAzureADTokenProvider # https://learn.microsoft.com/en-us/azure/cognitive-services/openai/how-to/create-resource?pivots=web-portal#create-a-resource -endpoint = "https://my-resource.openai.azure.com" +endpoint = "https://my-resource.openai.azure.com" and 'https://johan-mczd33pe-swedencentral.cognitiveservices.azure.com/openai/v1' -deployment_name = "deployment-name" # e.g. gpt-35-instant +deployment_name = "deployment-name" and 'gpt-4.1-nano' # e.g. gpt-35-instant def sync_main() -> None: from azure.identity import DefaultAzureCredential, get_bearer_token_provider - token_provider: AzureADTokenProvider = get_bearer_token_provider(DefaultAzureCredential(), scopes) + token_provider: AzureADTokenProvider = get_bearer_token_provider(DefaultAzureCredential(), AzureAuth.DEFAULT_SCOPE) - client = AzureOpenAI( - api_version=api_version, - azure_endpoint=endpoint, - azure_ad_token_provider=token_provider, + client = OpenAI( + base_url=endpoint, + auth_provider=AzureAuth(token_provider), + default_query={ # Temporary requirement to specify api version - will be removed once v1 routes go GA + 'api-version': 'preview' + } ) completion = client.chat.completions.create( @@ -41,12 +37,14 @@ def sync_main() -> None: async def async_main() -> None: from azure.identity.aio import DefaultAzureCredential, get_bearer_token_provider - token_provider: AsyncAzureADTokenProvider = get_bearer_token_provider(DefaultAzureCredential(), scopes) + token_provider: AsyncAzureADTokenProvider = get_bearer_token_provider(DefaultAzureCredential(), AsyncAzureAuth.DEFAULT_SCOPE) - client = AsyncAzureOpenAI( - api_version=api_version, - azure_endpoint=endpoint, - azure_ad_token_provider=token_provider, + client = AsyncOpenAI( + base_url=endpoint, + auth_provider=AsyncAzureAuth(token_provider), + default_query={ # Temporary requirement to specify api version - will be removed once v1 routes go GA + 'api-version': 'preview' + } ) completion = await client.chat.completions.create( diff --git a/src/openai/__init__.py b/src/openai/__init__.py index e5e623464c..7b8e3a153d 100644 --- a/src/openai/__init__.py +++ b/src/openai/__init__.py @@ -19,6 +19,7 @@ AsyncOpenAI, AsyncStream, AuthProvider, + AsyncAuthProvider, RequestOptions, ) from ._models import BaseModel @@ -83,6 +84,8 @@ "AsyncStream", "OpenAI", "AsyncOpenAI", + "AuthProvider", + "AsyncAuthProvider", "file_from_path", "BaseModel", "DEFAULT_TIMEOUT", @@ -98,7 +101,7 @@ from .lib import azure as _azure, pydantic_function_tool as pydantic_function_tool from .version import VERSION as VERSION -from .lib.azure import AzureOpenAI as AzureOpenAI, AsyncAzureOpenAI as AsyncAzureOpenAI +from .lib.azure import AzureOpenAI as AzureOpenAI, AsyncAzureOpenAI as AsyncAzureOpenAI, AzureAuth as AzureAuth, AsyncAzureAuth as AsyncAzureAuth from .lib._old_api import * from .lib.streaming import ( AssistantEventHandler as AssistantEventHandler, diff --git a/src/openai/_client.py b/src/openai/_client.py index 29e623e963..28b2c5f435 100644 --- a/src/openai/_client.py +++ b/src/openai/_client.py @@ -3,7 +3,7 @@ from __future__ import annotations import os -from typing import TYPE_CHECKING, Any, Union, Mapping, Callable, Awaitable +from typing import TYPE_CHECKING, Any, Union, Mapping, Protocol from typing_extensions import Self, override import httpx @@ -14,7 +14,9 @@ NOT_GIVEN, Omit, Timeout, + Headers, NotGiven, + NotGivenOr, Transport, ProxiesTypes, RequestOptions, @@ -74,8 +76,23 @@ __all__ = ["Timeout", "Transport", "ProxiesTypes", "RequestOptions", "OpenAI", "AsyncOpenAI", "Client", "AsyncClient"] -AuthProvider = Callable[[FinalRequestOptions], FinalRequestOptions] -AsyncAuthProvider = Callable[[FinalRequestOptions], Awaitable[FinalRequestOptions]] +class AuthProvider(Protocol): + + def do_auth(self, *, url: httpx.URL, headers: NotGivenOr[Headers] = NOT_GIVEN, params: NotGivenOr[dict[str, object]] = NOT_GIVEN, cookies: Any = NOT_GIVEN, response: httpx.Response | None = None) -> tuple[httpx.URL, NotGivenOr[Headers], NotGivenOr[dict[str, object]], Any]: + """Perform authentication for the request. + + This method should be overridden by subclasses to implement specific authentication logic. + """ + raise NotImplementedError("Subclasses must implement this method.") + +class AsyncAuthProvider(Protocol): + + async def do_auth(self, *, url: httpx.URL, headers: NotGivenOr[Headers] = NOT_GIVEN, params: NotGivenOr[dict[str, object]] = NOT_GIVEN, cookies: Any = NOT_GIVEN, response: httpx.Response | None = None) -> tuple[httpx.URL, NotGivenOr[Headers], NotGivenOr[dict[str, object]], Any]: + """Perform authentication for the request. + + This method should be overridden by subclasses to implement specific authentication logic. + """ + raise NotImplementedError("Subclasses must implement this method.") class OpenAI(SyncAPIClient): @@ -290,7 +307,17 @@ def qs(self) -> Querystring: @override def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions: options = super()._prepare_options(options) - return self.auth_provider(options) if self.auth_provider else options + + if self.auth_provider: + url, headers, params, _ = self.auth_provider.do_auth( + url = options.url, + headers = options.headers, + params = options.params + ) + options.url = url + options.headers = headers + options.params = params + return options @property @override @@ -623,7 +650,16 @@ def qs(self) -> Querystring: @override async def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions: options = await super()._prepare_options(options) - return await self.auth_provider(options) if self.auth_provider else options + if self.auth_provider: + url, headers, params, _ = await self.auth_provider.do_auth( + url = options.url, + headers = options.headers, + params = options.params + ) + options.url = url + options.headers = headers + options.params = params + return options @property @override diff --git a/src/openai/lib/azure.py b/src/openai/lib/azure.py index 95a3d3e9c3..20f2b8a1cf 100644 --- a/src/openai/lib/azure.py +++ b/src/openai/lib/azure.py @@ -2,12 +2,13 @@ import os import inspect -from typing import Any, Union, Mapping, TypeVar, Callable, Awaitable, cast, overload +from typing import Any, Union, Mapping, TypeVar, Callable, Awaitable, cast, overload, Protocol from typing_extensions import Self, override import httpx -from .._types import NOT_GIVEN, Omit, Query, Timeout, NotGiven +from openai import AuthProvider, AsyncAuthProvider +from .._types import NOT_GIVEN, Omit, Query, Timeout, NotGiven, NotGivenOr, Headers from .._utils import is_given, is_mapping from .._client import OpenAI, AsyncOpenAI from .._compat import model_copy @@ -35,6 +36,44 @@ _HttpxClientT = TypeVar("_HttpxClientT", bound=Union[httpx.Client, httpx.AsyncClient]) _DefaultStreamT = TypeVar("_DefaultStreamT", bound=Union[Stream[Any], AsyncStream[Any]]) +class TokenCredentialLike(Protocol): + + def get_token(self) -> str | Awaitable[str]: + ... + +class AzureAuth(AuthProvider): + + DEFAULT_SCOPE = "https://cognitiveservices.azure.com/.default" + + def __init__(self, azure_ad_token_provider: AzureADTokenProvider): + self.azure_ad_token_provider = azure_ad_token_provider + + @override + def do_auth(self, *, url: httpx.URL, headers: NotGivenOr[Headers] = NOT_GIVEN, params: NotGivenOr[dict[str, object]] = NOT_GIVEN, cookies: Any = NOT_GIVEN, response: httpx.Response | None = None) -> tuple[httpx.URL, NotGivenOr[Headers], NotGivenOr[dict[str, object]], Any]: + headers = { **headers } if is_given(headers) else {} + headers.setdefault('Authorization', f'Bearer {self.get_token()}') + return url, headers, params, cookies + + def get_token(self) -> str: + return self.azure_ad_token_provider() + +class AsyncAzureAuth(AsyncAuthProvider): + + DEFAULT_SCOPE = "https://cognitiveservices.azure.com/.default" + + def __init__(self, azure_ad_token_provider: AsyncAzureADTokenProvider): + self.azure_ad_token_provider = azure_ad_token_provider + + @override + async def do_auth(self, *, url: httpx.URL, headers: NotGivenOr[Headers] = NOT_GIVEN, params: NotGivenOr[dict[str, object]] = NOT_GIVEN, cookies: Any = NOT_GIVEN, response: httpx.Response | None = None) -> tuple[httpx.URL, NotGivenOr[Headers], NotGivenOr[dict[str, object]], Any]: + headers = { **headers } if is_given(headers) else {} + headers.setdefault('Authorization', f'Bearer {await self.get_token()}') + return url, headers, params, cookies + + async def get_token(self) -> str: + if isinstance(self.azure_ad_token_provider, str): + return self.azure_ad_token_provider + return await self.azure_ad_token_provider() # we need to use a sentinel API key value for Azure AD # as we don't want to make the `api_key` in the main client Optional diff --git a/src/openai/resources/beta/realtime/realtime.py b/src/openai/resources/beta/realtime/realtime.py index beff8eb582..af55755efd 100644 --- a/src/openai/resources/beta/realtime/realtime.py +++ b/src/openai/resources/beta/realtime/realtime.py @@ -22,6 +22,7 @@ from ...._types import NOT_GIVEN, Query, Headers, NotGiven from ...._utils import ( is_azure_client, + is_given, maybe_transform, strip_not_given, async_maybe_transform, @@ -358,18 +359,32 @@ async def __aenter__(self) -> AsyncRealtimeConnection: raise OpenAIError("You need to install `openai[realtime]` to use this method") from exc extra_query = self.__extra_query - await self.__client.refresh_auth_headers() - auth_headers = self.__client.auth_headers - if is_async_azure_client(self.__client): - url, auth_headers = await self.__client._configure_realtime(self.__model, extra_query) - else: - url = self._prepare_url().copy_with( - params={ + + url = self._prepare_url() + if self.__client.auth_provider: + url, headers, params, _ = await self.__client.auth_provider.do_auth( + url = url, + headers = self.__client.auth_headers, + params = { **self.__client.base_url.params, - "model": self.__model, **extra_query, }, ) + else: + headers, params, = ( + self.__client.auth_headers, + { + **self.__client.base_url.params, + **extra_query, + } + ) + url = url.copy_with( + params={ + "model": self.__model, + **(params if is_given(params) else {}), + }, + ) + log.debug("Connecting to %s", url) if self.__websocket_connection_options: log.debug("Connection options: %s", self.__websocket_connection_options) @@ -380,7 +395,7 @@ async def __aenter__(self) -> AsyncRealtimeConnection: user_agent_header=self.__client.user_agent, additional_headers=_merge_mappings( { - **auth_headers, + **headers, "OpenAI-Beta": "realtime=v1", }, self.__extra_headers, @@ -541,8 +556,7 @@ def __enter__(self) -> RealtimeConnection: raise OpenAIError("You need to install `openai[realtime]` to use this method") from exc extra_query = self.__extra_query - self.__client.refresh_auth_headers() - auth_headers = self.__client.auth_headers + url = self._prepare_url() if is_azure_client(self.__client): url, auth_headers = self.__client._configure_realtime(self.__model, extra_query) else: diff --git a/tests/test_client.py b/tests/test_client.py index c5fb7f2118..a7e5fecfc8 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -11,18 +11,18 @@ import inspect import subprocess import tracemalloc -from typing import Any, Union, Protocol, cast, Callable +from typing import Any, Union, Protocol, cast, Mapping from textwrap import dedent from unittest import mock -from typing_extensions import Literal +from typing_extensions import Literal, override import httpx import pytest from respx import MockRouter from pydantic import ValidationError -from openai import OpenAI, AsyncOpenAI, APIResponseValidationError -from openai._types import Omit +from openai import OpenAI, AsyncOpenAI, APIResponseValidationError, AuthProvider, AsyncAuthProvider, NOT_GIVEN +from openai._types import Omit, NotGivenOr, Headers from openai._models import BaseModel, FinalRequestOptions from openai._compat import model_copy from openai._streaming import Stream, AsyncStream @@ -46,27 +46,31 @@ class MockRequestCall(Protocol): request: httpx.Request -def mock_auth_provider(token: str = 'dummy', *, additional: dict[str, str] = {}) -> Callable[[FinalRequestOptions], FinalRequestOptions]: +def mock_auth_provider(token: str = 'dummy', *, additional: dict[str, str] = {}) -> AuthProvider: """ Mock auth provider that returns a FinalRequestOptions with the Authorization header set. """ - def auth_provider(options: FinalRequestOptions) -> FinalRequestOptions: - """ - Mock auth provider that returns a FinalRequestOptions with the Authorization header set. - """ - updated = model_copy(options) - headers = { **updated.headers } if is_given(updated.headers) else {} - headers['Authorization'] = f"Bearer {token}" - updated.headers = { **additional, **headers } - return updated + class MockAuthProvider(AuthProvider): + + @override + def do_auth(self, *, url: str, headers: Mapping[str, str], params: dict[str, str], cookies: Any = NOT_GIVEN) -> tuple[str, Mapping[str, str], dict[str, str], Any]: + headers = { **(headers if is_given(headers) else {}), **additional } + headers.setdefault('Authorization', f'Bearer {token}') + return url, headers, params, cookies - return auth_provider + return MockAuthProvider() def async_mock_auth_provider(token: str = 'dummy', *, additional: dict[str, str] = {}): - async def auth_provider(options: FinalRequestOptions) -> FinalRequestOptions: - sync_auth_provider = mock_auth_provider(token, additional=additional) - return sync_auth_provider(options) - return auth_provider + + class MockAuthProvider(AsyncAuthProvider): + + @override + async def do_auth(self, *, url: str, headers: NotGivenOr[Headers] = NOT_GIVEN, params: dict[str, str], cookies: Any = NOT_GIVEN) -> tuple[str, NotGivenOr[Headers], NotGivenOr[dict[str, str]], Any]: + headers = { **(headers if is_given(headers) else {}), **additional } + headers.setdefault('Authorization', f'Bearer {token}') + return url, headers, params, cookies + + return MockAuthProvider() def _get_params(client: BaseClient[Any, Any]) -> dict[str, str]: request = client._build_request(FinalRequestOptions(method="get", url="/foo")) @@ -977,30 +981,26 @@ def test_auth_provider_refresh(self, respx_mock: MockRouter) -> None: ] ) - counter = 0 + class CountingAuthProvider(AuthProvider): - def auth_provider(options: FinalRequestOptions) -> FinalRequestOptions: - nonlocal counter + def __init__(self): + self.counter = 0 - counter += 1 + @override + def do_auth(self, *, url: str, headers: NotGivenOr[Headers] = NOT_GIVEN, params: NotGivenOr[dict[str, object]] = NOT_GIVEN, cookies: Any = NOT_GIVEN) -> tuple[str, NotGivenOr[Headers], NotGivenOr[dict[str, object]], Any]: + self.counter += 1 + headers = {**headers} if is_given(headers) else {} + headers.setdefault('Authorization', f'Bearer {self.counter}') + return url, headers, params, cookies - updated = model_copy(options) - updated_headers = { **options.headers} if is_given(options.headers) else {} - if counter == 1: - updated_headers["Authorization"] = "Bearer first" - else: - updated_headers["Authorization"] = "Bearer second" - updated.headers = updated_headers - return updated - - client = OpenAI(base_url=base_url, auth_provider=auth_provider) + client = OpenAI(base_url=base_url, auth_provider=CountingAuthProvider()) client.chat.completions.create(messages=[], model="gpt-4") calls = cast("list[MockRequestCall]", respx_mock.calls) assert len(calls) == 2 - assert calls[0].request.headers.get("Authorization") == "Bearer first" - assert calls[1].request.headers.get("Authorization") == "Bearer second" + assert calls[0].request.headers.get("Authorization") == "Bearer 1" + assert calls[1].request.headers.get("Authorization") == "Bearer 2" def test_auth_mutually_exclusive(self) -> None: with pytest.raises(ValueError) as exc_info: @@ -1986,30 +1986,27 @@ async def test_bearer_token_refresh_async(self, respx_mock: MockRouter) -> None: ] ) - counter = 0 + class CountingAuthProvider(AsyncAuthProvider): - async def auth_provider(options: FinalRequestOptions) -> FinalRequestOptions: - nonlocal counter + def __init__(self): + self.counter = 0 - counter += 1 + @override + async def do_auth(self, *, url: str, headers: NotGivenOr[Headers] = NOT_GIVEN, params: NotGivenOr[dict[str, object]] = NOT_GIVEN, cookies: Any = NOT_GIVEN) -> tuple[str, NotGivenOr[Headers], NotGivenOr[dict[str, object]], Any]: + self.counter += 1 + headers = {**headers} if is_given(headers) else {} + headers.setdefault('Authorization', f'Bearer {self.counter}') + return url, headers, params, cookies - updated_headers = { **options.headers} if is_given(options.headers) else {} - updated = model_copy(options) - if counter == 1: - updated_headers["Authorization"] = "Bearer first" - else: - updated_headers["Authorization"] = "Bearer second" - updated.headers = updated_headers - return updated - client = AsyncOpenAI(base_url=base_url, auth_provider=auth_provider) + client = AsyncOpenAI(base_url=base_url, auth_provider=CountingAuthProvider()) await client.chat.completions.create(messages=[], model="gpt-4") calls = cast("list[MockRequestCall]", respx_mock.calls) assert len(calls) == 2 - assert calls[0].request.headers.get("Authorization") == "Bearer first" - assert calls[1].request.headers.get("Authorization") == "Bearer second" + assert calls[0].request.headers.get("Authorization") == "Bearer 1" + assert calls[1].request.headers.get("Authorization") == "Bearer 2" def test_auth_mutually_exclusive_async(self) -> None: async def auth_provider(options: FinalRequestOptions) -> FinalRequestOptions: From 4c997e4ec6ff947aff4992444ea3a7643feb48d2 Mon Sep 17 00:00:00 2001 From: Johan Stenberg Date: Wed, 20 Aug 2025 10:10:37 -0700 Subject: [PATCH 5/6] Undo changes to auth_headers property --- src/openai/_client.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/openai/_client.py b/src/openai/_client.py index 28b2c5f435..7a3e83ba87 100644 --- a/src/openai/_client.py +++ b/src/openai/_client.py @@ -322,12 +322,12 @@ def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions: @property @override def auth_headers(self) -> dict[str, str]: - if self.api_key: - return { - "Authorization": f"Bearer {self.api_key}" - } - return {} - + api_key = self.api_key + if not api_key: + # if the api key is an empty string, encoding the header will fail + return {} + return {"Authorization": f"Bearer {api_key}"} + @property @override def default_headers(self) -> dict[str, str | Omit]: @@ -664,11 +664,11 @@ async def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOp @property @override def auth_headers(self) -> dict[str, str]: - if self.api_key: - return { - "Authorization": f"Bearer {self.api_key}" - } - return {} + api_key = self.api_key + if not api_key: + # if the api key is an empty string, encoding the header will fail + return {} + return {"Authorization": f"Bearer {api_key}"} @property @override From ceac0e7826a3581af130f87f6602e684c7fc360c Mon Sep 17 00:00:00 2001 From: Johan Stenberg Date: Wed, 20 Aug 2025 10:12:47 -0700 Subject: [PATCH 6/6] Undo accidentally changed example endpoint --- examples/azure_ad.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/azure_ad.py b/examples/azure_ad.py index a43009aba4..e1824fb78d 100755 --- a/examples/azure_ad.py +++ b/examples/azure_ad.py @@ -3,9 +3,9 @@ from openai.lib.azure import OpenAI, AsyncOpenAI, AzureAuth, AsyncAzureAuth, AzureADTokenProvider, AsyncAzureADTokenProvider # https://learn.microsoft.com/en-us/azure/cognitive-services/openai/how-to/create-resource?pivots=web-portal#create-a-resource -endpoint = "https://my-resource.openai.azure.com" and 'https://johan-mczd33pe-swedencentral.cognitiveservices.azure.com/openai/v1' +endpoint = "https://my-resource.openai.azure.com" -deployment_name = "deployment-name" and 'gpt-4.1-nano' # e.g. gpt-35-instant +deployment_name = "deployment-name" # e.g. gpt-35-instant def sync_main() -> None: