Skip to content

Commit 60dba27

Browse files
committed
Review feedback + make sure you can swap between callable and string api key values for module level client.
1 parent 1ffd959 commit 60dba27

File tree

6 files changed

+98
-73
lines changed

6 files changed

+98
-73
lines changed

src/openai/__init__.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -157,15 +157,26 @@ class _ModuleClient(OpenAI):
157157
@property # type: ignore
158158
@override
159159
def api_key(self) -> str | _t.Callable[[], str] | None:
160-
return api_key
160+
return api_key() if callable(api_key) else api_key
161161

162162
@api_key.setter # type: ignore
163163
def api_key(self, value: str | _t.Callable[[], str] | None) -> None: # type: ignore
164164
global api_key
165165
api_key = value
166166

167+
@property
168+
def _api_key_provider(self) -> _t.Callable[[], str] | None: # type: ignore
169+
return None
167170

168-
@property # type: ignore
171+
@_api_key_provider.setter
172+
def _api_key_provider(self, value: _t.Callable[[], str] | None) -> None: # type: ignore
173+
global api_key
174+
# Yes, setting the api_key is intentional. The module level client accepts callables
175+
# for the module level api_key and will call it to retrieve the value
176+
# if it is a callable.
177+
api_key = value
178+
179+
@property
169180
@override
170181
def organization(self) -> str | None:
171182
return organization

src/openai/_client.py

Lines changed: 24 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,6 @@
8181
class OpenAI(SyncAPIClient):
8282
# client options
8383
api_key: str
84-
bearer_token_provider: Callable[[], str] | None = None
8584
organization: str | None
8685
project: str | None
8786
webhook_secret: str | None
@@ -137,10 +136,10 @@ def __init__(
137136
)
138137
if callable(api_key):
139138
self.api_key = ""
140-
self.bearer_token_provider = api_key
139+
self._api_key_provider: Callable[[], str] | None = api_key
141140
else:
142141
self.api_key = api_key or ""
143-
self.bearer_token_provider = None
142+
self._api_key_provider = None
144143

145144
if organization is None:
146145
organization = os.environ.get("OPENAI_ORG_ID")
@@ -173,7 +172,6 @@ def __init__(
173172
)
174173

175174
self._default_stream_cls = Stream
176-
self._auth_headers: dict[str, str] = {}
177175

178176
@cached_property
179177
def completions(self) -> Completions:
@@ -296,28 +294,23 @@ def with_streaming_response(self) -> OpenAIWithStreamedResponse:
296294
def qs(self) -> Querystring:
297295
return Querystring(array_format="brackets")
298296

299-
def refresh_auth_headers(self) -> None:
300-
if self.bearer_token_provider:
301-
secret = self.bearer_token_provider()
302-
else:
303-
secret = self.api_key
304-
if not secret:
305-
# if the api key is an empty string, encoding the header will fail
306-
# so we set it to an empty dict
307-
# this is to avoid sending an invalid Authorization header
308-
self._auth_headers = {}
309-
else:
310-
self._auth_headers = {"Authorization": f"Bearer {secret}"}
297+
def _refresh_api_key(self) -> None:
298+
if self._api_key_provider:
299+
self.api_key = self._api_key_provider()
311300

312301
@override
313302
def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions:
314-
self.refresh_auth_headers()
303+
self._refresh_api_key()
315304
return super()._prepare_options(options)
316305

317306
@property
318307
@override
319308
def auth_headers(self) -> dict[str, str]:
320-
return self._auth_headers
309+
api_key = self.api_key
310+
if not api_key:
311+
# if the api key is an empty string, encoding the header will fail
312+
return {}
313+
return {"Authorization": f"Bearer {api_key}"}
321314

322315
@property
323316
@override
@@ -371,7 +364,7 @@ def copy(
371364

372365
http_client = http_client or self._client
373366
return self.__class__(
374-
api_key=api_key or self.bearer_token_provider or self.api_key,
367+
api_key=api_key or self._api_key_provider or self.api_key,
375368
organization=organization or self.organization,
376369
project=project or self.project,
377370
webhook_secret=webhook_secret or self.webhook_secret,
@@ -427,7 +420,6 @@ def _make_status_error(
427420
class AsyncOpenAI(AsyncAPIClient):
428421
# client options
429422
api_key: str
430-
bearer_token_provider: Callable[[], Awaitable[str]] | None = None
431423
organization: str | None
432424
project: str | None
433425
webhook_secret: str | None
@@ -483,10 +475,10 @@ def __init__(
483475
)
484476
if callable(api_key):
485477
self.api_key = ""
486-
self.bearer_token_provider = api_key
478+
self._api_key_provider: Callable[[], Awaitable[str]] | None = api_key
487479
else:
488480
self.api_key = api_key or ""
489-
self.bearer_token_provider = None
481+
self._api_key_provider = None
490482

491483
if organization is None:
492484
organization = os.environ.get("OPENAI_ORG_ID")
@@ -519,7 +511,6 @@ def __init__(
519511
)
520512

521513
self._default_stream_cls = AsyncStream
522-
self._auth_headers: dict[str, str] = {}
523514

524515
@cached_property
525516
def completions(self) -> AsyncCompletions:
@@ -642,28 +633,23 @@ def with_streaming_response(self) -> AsyncOpenAIWithStreamedResponse:
642633
def qs(self) -> Querystring:
643634
return Querystring(array_format="brackets")
644635

645-
async def refresh_auth_headers(self) -> None:
646-
if self.bearer_token_provider:
647-
secret = await self.bearer_token_provider()
648-
else:
649-
secret = self.api_key
650-
if not secret:
651-
# if the api key is an empty string, encoding the header will fail
652-
# so we set it to an empty dict
653-
# this is to avoid sending an invalid Authorization header
654-
self._auth_headers = {}
655-
else:
656-
self._auth_headers = {"Authorization": f"Bearer {secret}"}
636+
async def _refresh_api_key(self) -> None:
637+
if self._api_key_provider:
638+
self.api_key = await self._api_key_provider()
657639

658640
@override
659641
async def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions:
660-
await self.refresh_auth_headers()
642+
await self._refresh_api_key()
661643
return await super()._prepare_options(options)
662644

663645
@property
664646
@override
665647
def auth_headers(self) -> dict[str, str]:
666-
return self._auth_headers
648+
api_key = self.api_key
649+
if not api_key:
650+
# if the api key is an empty string, encoding the header will fail
651+
return {}
652+
return {"Authorization": f"Bearer {api_key}"}
667653

668654
@property
669655
@override
@@ -717,7 +703,7 @@ def copy(
717703

718704
http_client = http_client or self._client
719705
return self.__class__(
720-
api_key=api_key or self.bearer_token_provider or self.api_key,
706+
api_key=api_key or self._api_key_provider or self.api_key,
721707
organization=organization or self.organization,
722708
project=project or self.project,
723709
webhook_secret=webhook_secret or self.webhook_secret,

src/openai/lib/azure.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -255,10 +255,10 @@ def __init__(
255255
self._azure_endpoint = httpx.URL(azure_endpoint) if azure_endpoint else None
256256

257257
@override
258-
def copy( # type: ignore
258+
def copy(
259259
self,
260260
*,
261-
api_key: str | None = None,
261+
api_key: str | Callable[[], str] | None = None,
262262
organization: str | None = None,
263263
project: str | None = None,
264264
webhook_secret: str | None = None,
@@ -301,7 +301,7 @@ def copy( # type: ignore
301301
},
302302
)
303303

304-
with_options = copy # type: ignore
304+
with_options = copy
305305

306306
def _get_azure_ad_token(self) -> str | None:
307307
if self._azure_ad_token is not None:
@@ -435,7 +435,7 @@ def __init__(
435435
azure_endpoint: str | None = None,
436436
azure_deployment: str | None = None,
437437
api_version: str | None = None,
438-
api_key: str | None = None,
438+
api_key: str | Callable[[], Awaitable[str]] | None = None,
439439
azure_ad_token: str | None = None,
440440
azure_ad_token_provider: AsyncAzureADTokenProvider | None = None,
441441
organization: str | None = None,
@@ -536,10 +536,10 @@ def __init__(
536536
self._azure_endpoint = httpx.URL(azure_endpoint) if azure_endpoint else None
537537

538538
@override
539-
def copy( # type: ignore
539+
def copy(
540540
self,
541541
*,
542-
api_key: str | None = None,
542+
api_key: str | Callable[[], Awaitable[str]] | None = None,
543543
organization: str | None = None,
544544
project: str | None = None,
545545
webhook_secret: str | None = None,
@@ -582,7 +582,7 @@ def copy( # type: ignore
582582
},
583583
)
584584

585-
with_options = copy # type: ignore
585+
with_options = copy
586586

587587
async def _get_azure_ad_token(self) -> str | None:
588588
if self._azure_ad_token is not None:

src/openai/resources/beta/realtime/realtime.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -358,7 +358,7 @@ async def __aenter__(self) -> AsyncRealtimeConnection:
358358
raise OpenAIError("You need to install `openai[realtime]` to use this method") from exc
359359

360360
extra_query = self.__extra_query
361-
await self.__client.refresh_auth_headers()
361+
await self.__client._refresh_api_key()
362362
auth_headers = self.__client.auth_headers
363363
if is_async_azure_client(self.__client):
364364
url, auth_headers = await self.__client._configure_realtime(self.__model, extra_query)
@@ -541,7 +541,7 @@ def __enter__(self) -> RealtimeConnection:
541541
raise OpenAIError("You need to install `openai[realtime]` to use this method") from exc
542542

543543
extra_query = self.__extra_query
544-
self.__client.refresh_auth_headers()
544+
self.__client._refresh_api_key()
545545
auth_headers = self.__client.auth_headers
546546
if is_azure_client(self.__client):
547547
url, auth_headers = self.__client._configure_realtime(self.__model, extra_query)

tests/test_client.py

Lines changed: 33 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -945,14 +945,24 @@ def test_follow_redirects_disabled(self, respx_mock: MockRouter) -> None:
945945
assert exc_info.value.response.status_code == 302
946946
assert exc_info.value.response.headers["Location"] == f"{base_url}/redirected"
947947

948-
def test_refresh_auth_headers_token(self) -> None:
948+
def test_api_key_before_after_refresh_provider(self) -> None:
949949
client = OpenAI(base_url=base_url, api_key=lambda: "test_bearer_token")
950-
client.refresh_auth_headers()
950+
951+
assert client.api_key == ""
952+
assert 'Authorization' not in client.auth_headers
953+
954+
client._refresh_api_key()
955+
956+
assert client.api_key == "test_bearer_token"
951957
assert client.auth_headers.get("Authorization") == "Bearer test_bearer_token"
952958

953-
def test_refresh_auth_headers_key(self) -> None:
959+
960+
def test_api_key_before_after_refresh_str(self) -> None:
954961
client = OpenAI(base_url=base_url, api_key="test_api_key")
955-
client.refresh_auth_headers()
962+
963+
assert client.auth_headers.get("Authorization") == "Bearer test_api_key"
964+
client._refresh_api_key()
965+
956966
assert client.auth_headers.get("Authorization") == "Bearer test_api_key"
957967

958968
@pytest.mark.respx()
@@ -985,12 +995,11 @@ def token_provider() -> str:
985995
assert calls[0].request.headers.get("Authorization") == "Bearer first"
986996
assert calls[1].request.headers.get("Authorization") == "Bearer second"
987997

988-
989998
def test_copy_auth(self) -> None:
990999
client = OpenAI(base_url=base_url, api_key=lambda: "test_bearer_token_1").copy(
9911000
api_key=lambda: "test_bearer_token_2"
9921001
)
993-
client.refresh_auth_headers()
1002+
client._refresh_api_key()
9941003
assert client.auth_headers == {"Authorization": "Bearer test_bearer_token_2"}
9951004

9961005

@@ -1944,18 +1953,28 @@ async def test_follow_redirects_disabled(self, respx_mock: MockRouter) -> None:
19441953
assert exc_info.value.response.headers["Location"] == f"{base_url}/redirected"
19451954

19461955
@pytest.mark.asyncio
1947-
async def test_refresh_auth_headers_token_async(self) -> None:
1948-
async def token_provider() -> str:
1956+
async def test_api_key_before_after_refresh_provider(self) -> None:
1957+
async def mock_api_key_provider():
19491958
return "test_bearer_token"
1959+
1960+
client = AsyncOpenAI(base_url=base_url, api_key=mock_api_key_provider)
19501961

1951-
client = AsyncOpenAI(base_url=base_url, api_key=token_provider)
1952-
await client.refresh_auth_headers()
1962+
assert client.api_key == ""
1963+
assert 'Authorization' not in client.auth_headers
1964+
1965+
await client._refresh_api_key()
1966+
1967+
assert client.api_key == "test_bearer_token"
19531968
assert client.auth_headers.get("Authorization") == "Bearer test_bearer_token"
19541969

1970+
19551971
@pytest.mark.asyncio
1956-
async def test_refresh_auth_headers_key_async(self) -> None:
1972+
async def test_api_key_before_after_refresh_str(self) -> None:
19571973
client = AsyncOpenAI(base_url=base_url, api_key="test_api_key")
1958-
await client.refresh_auth_headers()
1974+
1975+
assert client.auth_headers.get("Authorization") == "Bearer test_api_key"
1976+
await client._refresh_api_key()
1977+
19591978
assert client.auth_headers.get("Authorization") == "Bearer test_api_key"
19601979

19611980
@pytest.mark.asyncio
@@ -1997,8 +2016,6 @@ async def token_provider_1() -> str:
19972016
async def token_provider_2() -> str:
19982017
return "test_bearer_token_2"
19992018

2000-
client = AsyncOpenAI(base_url=base_url, api_key=token_provider_1).copy(
2001-
api_key=token_provider_2
2002-
)
2003-
await client.refresh_auth_headers()
2019+
client = AsyncOpenAI(base_url=base_url, api_key=token_provider_1).copy(api_key=token_provider_2)
2020+
await client._refresh_api_key()
20042021
assert client.auth_headers == {"Authorization": "Bearer test_bearer_token_2"}

tests/test_module_client.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -97,12 +97,22 @@ def test_http_client_option() -> None:
9797
assert openai.completions._client._client is new_client
9898

9999

100-
def test_bearer_token_provider_option() -> None:
101-
openai.api_key = lambda: "foo"
100+
def test_api_key_callable() -> None:
101+
openai.api_key = lambda: "1"
102+
assert openai.completions._client.api_key == "1"
102103

103-
assert openai.completions._client.bearer_token_provider
104-
assert openai.completions._client.bearer_token_provider() == "foo"
104+
def test_api_key_overridable() -> None:
105+
openai.api_key = lambda: "1"
106+
assert openai.completions._client.api_key == "1"
107+
assert openai.completions._client._api_key_provider is None
105108

109+
openai.api_key = "2"
110+
assert openai.completions._client.api_key == "2"
111+
assert openai.completions._client._api_key_provider is None
112+
113+
openai.api_key = lambda: "3"
114+
assert openai.completions._client.api_key == "3"
115+
assert openai.completions._client._api_key_provider is None
106116

107117
import contextlib
108118
from typing import Iterator
@@ -130,22 +140,23 @@ def test_only_api_key_results_in_openai_api() -> None:
130140
assert type(openai.completions._client).__name__ == "_ModuleClient"
131141

132142

133-
def test_only_bearer_token_provider_in_openai_api() -> None:
143+
def test_only_api_key_in_openai_api() -> None:
134144
with fresh_env():
135145
openai.api_type = None
136146
openai.api_key = lambda: "example bearer token"
137147

138148
assert type(openai.completions._client).__name__ == "_ModuleClient"
139149

140150

141-
def test_both_api_key_and_bearer_token_provider_in_openai_api() -> None:
151+
def test_both_api_key_and_api_key_provider_in_openai_api() -> None:
142152
with fresh_env():
143153
openai.api_key = lambda: "example bearer token"
144154

145-
assert(openai.api_key() == "example bearer token")
155+
assert openai.api_key() == "example bearer token"
146156

147157
openai.api_key = "example API key"
148-
assert(openai.api_key == "example API key")
158+
assert openai.api_key == "example API key"
159+
149160

150161
def test_azure_api_key_env_without_api_version() -> None:
151162
with fresh_env():

0 commit comments

Comments
 (0)