Skip to content

Commit 3fc9edf

Browse files
committed
Add refresh auth headers (sync and async) as alternate approach to allow bearer tokens to be updated
Allow api_key to be a callable to enable refresh of keys/tokens.
1 parent 4e28a42 commit 3fc9edf

File tree

3 files changed

+41
-18
lines changed

3 files changed

+41
-18
lines changed

src/openai/_client.py

Lines changed: 38 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,13 @@
33
from __future__ import annotations
44

55
import os
6-
from typing import TYPE_CHECKING, Any, Union, Mapping
6+
from typing import TYPE_CHECKING, Any, Union, Mapping, Callable, Awaitable
77
from typing_extensions import Self, override
88

99
import httpx
1010

11+
from openai._models import FinalRequestOptions
12+
1113
from . import _exceptions
1214
from ._qs import Querystring
1315
from ._types import (
@@ -95,6 +97,7 @@ def __init__(
9597
self,
9698
*,
9799
api_key: str | None = None,
100+
bearer_token_provider: Callable[[], str] | None = None,
98101
organization: str | None = None,
99102
project: str | None = None,
100103
webhook_secret: str | None = None,
@@ -128,11 +131,12 @@ def __init__(
128131
"""
129132
if api_key is None:
130133
api_key = os.environ.get("OPENAI_API_KEY")
131-
if api_key is None:
134+
if api_key is None and bearer_token_provider is None:
132135
raise OpenAIError(
133136
"The api_key client option must be set either by passing api_key to the client or by setting the OPENAI_API_KEY environment variable"
134137
)
135-
self.api_key = api_key
138+
self.bearer_token_provider = bearer_token_provider
139+
self.api_key = api_key or ''
136140

137141
if organization is None:
138142
organization = os.environ.get("OPENAI_ORG_ID")
@@ -165,6 +169,7 @@ def __init__(
165169
)
166170

167171
self._default_stream_cls = Stream
172+
self._auth_headers: dict[str, str] = {}
168173

169174
@cached_property
170175
def completions(self) -> Completions:
@@ -281,21 +286,26 @@ def with_raw_response(self) -> OpenAIWithRawResponse:
281286
@cached_property
282287
def with_streaming_response(self) -> OpenAIWithStreamedResponse:
283288
return OpenAIWithStreamedResponse(self)
284-
285289
@property
286290
@override
287291
def qs(self) -> Querystring:
288292
return Querystring(array_format="brackets")
289293

294+
def refresh_auth_headers(self):
295+
bearer_token = self.bearer_token_provider() if self.bearer_token_provider else self.api_key
296+
self._auth_headers = {"Authorization": f"Bearer {bearer_token}"}
297+
298+
299+
@override
300+
def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions:
301+
self.refresh_auth_headers()
302+
return super()._prepare_options(options)
303+
290304
@property
291305
@override
292306
def auth_headers(self) -> dict[str, str]:
293-
api_key = self.api_key
294-
if not api_key:
295-
# if the api key is an empty string, encoding the header will fail
296-
return {}
297-
return {"Authorization": f"Bearer {api_key}"}
298-
307+
return self._auth_headers
308+
299309
@property
300310
@override
301311
def default_headers(self) -> dict[str, str | Omit]:
@@ -420,6 +430,7 @@ def __init__(
420430
self,
421431
*,
422432
api_key: str | None = None,
433+
bearer_token_provider: Callable[[], Awaitable[str]] | None = None,
423434
organization: str | None = None,
424435
project: str | None = None,
425436
webhook_secret: str | None = None,
@@ -453,11 +464,12 @@ def __init__(
453464
"""
454465
if api_key is None:
455466
api_key = os.environ.get("OPENAI_API_KEY")
456-
if api_key is None:
467+
if api_key is None and bearer_token_provider is None:
457468
raise OpenAIError(
458469
"The api_key client option must be set either by passing api_key to the client or by setting the OPENAI_API_KEY environment variable"
459470
)
460-
self.api_key = api_key
471+
self.bearer_token_provider = bearer_token_provider
472+
self.api_key = api_key or ''
461473

462474
if organization is None:
463475
organization = os.environ.get("OPENAI_ORG_ID")
@@ -490,6 +502,7 @@ def __init__(
490502
)
491503

492504
self._default_stream_cls = AsyncStream
505+
self._auth_headers: dict[str, str] = {}
493506

494507
@cached_property
495508
def completions(self) -> AsyncCompletions:
@@ -612,14 +625,22 @@ def with_streaming_response(self) -> AsyncOpenAIWithStreamedResponse:
612625
def qs(self) -> Querystring:
613626
return Querystring(array_format="brackets")
614627

628+
async def refresh_auth_headers(self):
629+
if self.bearer_token_provider:
630+
bearer_token = await self.bearer_token_provider()
631+
else:
632+
bearer_token = self.api_key
633+
self._auth_headers = {"Authorization": f"Bearer {bearer_token}"}
634+
635+
@override
636+
async def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions:
637+
await self.refresh_auth_headers()
638+
return await super()._prepare_options(options)
639+
615640
@property
616641
@override
617642
def auth_headers(self) -> dict[str, str]:
618-
api_key = self.api_key
619-
if not api_key:
620-
# if the api key is an empty string, encoding the header will fail
621-
return {}
622-
return {"Authorization": f"Bearer {api_key}"}
643+
return self._auth_headers
623644

624645
@property
625646
@override

src/openai/lib/azure.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -628,7 +628,7 @@ async def _configure_realtime(self, model: str, extra_query: Query) -> tuple[htt
628628
"api-version": self._api_version,
629629
"deployment": self._azure_deployment or model,
630630
}
631-
if self.api_key != "<missing API key>":
631+
if self.api_key and self.api_key != "<missing API key>":
632632
auth_headers = {"api-key": self.api_key}
633633
else:
634634
token = await self._get_azure_ad_token()

src/openai/resources/beta/realtime/realtime.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,7 @@ async def __aenter__(self) -> AsyncRealtimeConnection:
358358
raise OpenAIError("You need to install `openai[realtime]` to use this method") from exc
359359

360360
extra_query = self.__extra_query
361+
await self.__client.refresh_auth_headers()
361362
auth_headers = self.__client.auth_headers
362363
if is_async_azure_client(self.__client):
363364
url, auth_headers = await self.__client._configure_realtime(self.__model, extra_query)
@@ -540,6 +541,7 @@ def __enter__(self) -> RealtimeConnection:
540541
raise OpenAIError("You need to install `openai[realtime]` to use this method") from exc
541542

542543
extra_query = self.__extra_query
544+
self.__client.refresh_auth_headers()
543545
auth_headers = self.__client.auth_headers
544546
if is_azure_client(self.__client):
545547
url, auth_headers = self.__client._configure_realtime(self.__model, extra_query)

0 commit comments

Comments
 (0)