Skip to content

Commit a6a6a2b

Browse files
committed
Make api_key callable to enable token refresh for openai client
1 parent 7873247 commit a6a6a2b

File tree

4 files changed

+42
-96
lines changed

4 files changed

+42
-96
lines changed

src/openai/__init__.py

Lines changed: 5 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -117,9 +117,7 @@
117117

118118
from ._base_client import DEFAULT_TIMEOUT, DEFAULT_MAX_RETRIES
119119

120-
api_key: str | None = None
121-
122-
bearer_token_provider: _t.Callable[[], str] | None = None
120+
api_key: str | _t.Callable[[], str] | None = None
123121

124122
organization: str | None = None
125123

@@ -158,25 +156,14 @@ class _ModuleClient(OpenAI):
158156

159157
@property # type: ignore
160158
@override
161-
def api_key(self) -> str | None:
159+
def api_key(self) -> str | _t.Callable[[], str] | None:
162160
return api_key
163161

164162
@api_key.setter # type: ignore
165-
def api_key(self, value: str | None) -> None: # type: ignore
163+
def api_key(self, value: str | _t.Callable[[], str] | None) -> None: # type: ignore
166164
global api_key
167-
168165
api_key = value
169166

170-
@property # type: ignore
171-
@override
172-
def bearer_token_provider(self) -> _t.Callable[[], str] | None:
173-
return bearer_token_provider
174-
175-
@bearer_token_provider.setter # type: ignore
176-
def bearer_token_provider(self, value: _t.Callable[[], str] | None) -> None: # type: ignore
177-
global bearer_token_provider
178-
179-
bearer_token_provider = value
180167

181168
@property # type: ignore
182169
@override
@@ -346,7 +333,7 @@ def _load_client() -> OpenAI: # type: ignore[reportUnusedFunction]
346333
_client = _AzureModuleClient( # type: ignore
347334
api_version=api_version,
348335
azure_endpoint=azure_endpoint,
349-
api_key=api_key,
336+
api_key=bearer_token_provider or api_key,
350337
azure_ad_token=azure_ad_token,
351338
azure_ad_token_provider=azure_ad_token_provider,
352339
organization=organization,
@@ -360,8 +347,7 @@ def _load_client() -> OpenAI: # type: ignore[reportUnusedFunction]
360347
return _client
361348

362349
_client = _ModuleClient(
363-
api_key=api_key,
364-
bearer_token_provider=bearer_token_provider,
350+
api_key=api_key or bearer_token_provider,
365351
organization=organization,
366352
project=project,
367353
webhook_secret=webhook_secret,

src/openai/_client.py

Lines changed: 23 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@
8181
class OpenAI(SyncAPIClient):
8282
# client options
8383
api_key: str
84+
bearer_token_provider: Callable[[], str] | None = None
8485
organization: str | None
8586
project: str | None
8687
webhook_secret: str | None
@@ -96,8 +97,7 @@ class OpenAI(SyncAPIClient):
9697
def __init__(
9798
self,
9899
*,
99-
api_key: str | None = None,
100-
bearer_token_provider: Callable[[], str] | None = None,
100+
api_key: str | None | Callable[[], str] = None,
101101
organization: str | None = None,
102102
project: str | None = None,
103103
webhook_secret: str | None = None,
@@ -129,16 +129,17 @@ def __init__(
129129
- `project` from `OPENAI_PROJECT_ID`
130130
- `webhook_secret` from `OPENAI_WEBHOOK_SECRET`
131131
"""
132-
if api_key and bearer_token_provider:
133-
raise ValueError("The `api_key` and `bearer_token_provider` arguments are mutually exclusive")
134132
if api_key is None:
135133
api_key = os.environ.get("OPENAI_API_KEY")
136-
if api_key is None and bearer_token_provider is None:
134+
if api_key is None:
137135
raise OpenAIError(
138-
"The api_key or bearer_token_provider client option must be set either by passing api_key or bearer_token_provider to the client or by setting the OPENAI_API_KEY environment variable"
136+
"The api_key client option must be set either by passing api_key to the client or by setting the OPENAI_API_KEY environment variable"
139137
)
140-
self.bearer_token_provider = bearer_token_provider
141-
self.api_key = api_key or ""
138+
if callable(api_key):
139+
self.bearer_token_provider = api_key
140+
self.api_key = ""
141+
else:
142+
self.api_key = api_key or ""
142143

143144
if organization is None:
144145
organization = os.environ.get("OPENAI_ORG_ID")
@@ -328,8 +329,7 @@ def default_headers(self) -> dict[str, str | Omit]:
328329
def copy(
329330
self,
330331
*,
331-
api_key: str | None = None,
332-
bearer_token_provider: Callable[[], str] | None = None,
332+
api_key: str | Callable[[], str] | None = None,
333333
organization: str | None = None,
334334
project: str | None = None,
335335
webhook_secret: str | None = None,
@@ -365,13 +365,9 @@ def copy(
365365
elif set_default_query is not None:
366366
params = set_default_query
367367

368-
bearer_token_provider = bearer_token_provider or self.bearer_token_provider
369-
if bearer_token_provider is not None:
370-
_extra_kwargs = {**_extra_kwargs, "bearer_token_provider": bearer_token_provider}
371-
372368
http_client = http_client or self._client
373369
return self.__class__(
374-
api_key=api_key or self.api_key,
370+
api_key=api_key or self.api_key or self.bearer_token_provider,
375371
organization=organization or self.organization,
376372
project=project or self.project,
377373
webhook_secret=webhook_secret or self.webhook_secret,
@@ -427,6 +423,7 @@ def _make_status_error(
427423
class AsyncOpenAI(AsyncAPIClient):
428424
# client options
429425
api_key: str
426+
bearer_token_provider: Callable[[], Awaitable[str]] | None = None
430427
organization: str | None
431428
project: str | None
432429
webhook_secret: str | None
@@ -442,8 +439,7 @@ class AsyncOpenAI(AsyncAPIClient):
442439
def __init__(
443440
self,
444441
*,
445-
api_key: str | None = None,
446-
bearer_token_provider: Callable[[], Awaitable[str]] | None = None,
442+
api_key: str | Callable[[], Awaitable[str]] | None = None,
447443
organization: str | None = None,
448444
project: str | None = None,
449445
webhook_secret: str | None = None,
@@ -475,16 +471,18 @@ def __init__(
475471
- `project` from `OPENAI_PROJECT_ID`
476472
- `webhook_secret` from `OPENAI_WEBHOOK_SECRET`
477473
"""
478-
if api_key and bearer_token_provider:
479-
raise ValueError("The `api_key` and `bearer_token_provider` arguments are mutually exclusive")
480474
if api_key is None:
481475
api_key = os.environ.get("OPENAI_API_KEY")
482-
if api_key is None and bearer_token_provider is None:
476+
if api_key is None:
483477
raise OpenAIError(
484-
"The api_key or bearer_token_provider client option must be set either by passing api_key or bearer_token_provider to the client or by setting the OPENAI_API_KEY environment variable"
478+
"The api_key client option must be set either by passing api_key to the client or by setting the OPENAI_API_KEY environment variable"
485479
)
486-
self.bearer_token_provider = bearer_token_provider
487-
self.api_key = api_key or ""
480+
if callable(api_key):
481+
self.bearer_token_provider = api_key
482+
self.api_key = ""
483+
else:
484+
self.bearer_token_provider = None
485+
self.api_key = api_key or ""
488486

489487
if organization is None:
490488
organization = os.environ.get("OPENAI_ORG_ID")
@@ -677,8 +675,7 @@ def default_headers(self) -> dict[str, str | Omit]:
677675
def copy(
678676
self,
679677
*,
680-
api_key: str | None = None,
681-
bearer_token_provider: Callable[[], Awaitable[str]] | None = None,
678+
api_key: str | Callable[[], Awaitable[str]] | None = None,
682679
organization: str | None = None,
683680
project: str | None = None,
684681
webhook_secret: str | None = None,
@@ -714,13 +711,9 @@ def copy(
714711
elif set_default_query is not None:
715712
params = set_default_query
716713

717-
bearer_token_provider = bearer_token_provider or self.bearer_token_provider
718-
if bearer_token_provider is not None:
719-
_extra_kwargs = {**_extra_kwargs, "bearer_token_provider": bearer_token_provider}
720-
721714
http_client = http_client or self._client
722715
return self.__class__(
723-
api_key=api_key or self.api_key,
716+
api_key=api_key or self.api_key or self.bearer_token_provider,
724717
organization=organization or self.organization,
725718
project=project or self.project,
726719
webhook_secret=webhook_secret or self.webhook_secret,

tests/test_client.py

Lines changed: 8 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -946,7 +946,7 @@ def test_follow_redirects_disabled(self, respx_mock: MockRouter) -> None:
946946
assert exc_info.value.response.headers["Location"] == f"{base_url}/redirected"
947947

948948
def test_refresh_auth_headers_token(self) -> None:
949-
client = OpenAI(base_url=base_url, bearer_token_provider=lambda: "test_bearer_token")
949+
client = OpenAI(base_url=base_url, api_key=lambda: "test_bearer_token")
950950
client.refresh_auth_headers()
951951
assert client.auth_headers.get("Authorization") == "Bearer test_bearer_token"
952952

@@ -976,7 +976,7 @@ def token_provider() -> str:
976976

977977
return "second"
978978

979-
client = OpenAI(base_url=base_url, bearer_token_provider=token_provider)
979+
client = OpenAI(base_url=base_url, api_key=token_provider)
980980
client.chat.completions.create(messages=[], model="gpt-4")
981981

982982
calls = cast("list[MockRequestCall]", respx_mock.calls)
@@ -985,23 +985,14 @@ def token_provider() -> str:
985985
assert calls[0].request.headers.get("Authorization") == "Bearer first"
986986
assert calls[1].request.headers.get("Authorization") == "Bearer second"
987987

988-
def test_auth_mutually_exclusive(self) -> None:
989-
with pytest.raises(ValueError) as exc_info:
990-
OpenAI(base_url=base_url, api_key=api_key, bearer_token_provider=lambda: "test_bearer_token")
991-
assert str(exc_info.value) == "The `api_key` and `bearer_token_provider` arguments are mutually exclusive"
992988

993989
def test_copy_auth(self) -> None:
994-
client = OpenAI(base_url=base_url, bearer_token_provider=lambda: "test_bearer_token_1").copy(
995-
bearer_token_provider=lambda: "test_bearer_token_2"
990+
client = OpenAI(base_url=base_url, api_key=lambda: "test_bearer_token_1").copy(
991+
api_key=lambda: "test_bearer_token_2"
996992
)
997993
client.refresh_auth_headers()
998994
assert client.auth_headers == {"Authorization": "Bearer test_bearer_token_2"}
999995

1000-
def test_copy_auth_mutually_exclusive(self) -> None:
1001-
with pytest.raises(ValueError) as exc_info:
1002-
OpenAI(base_url=base_url, api_key=api_key).copy(bearer_token_provider=lambda: "test_bearer_token")
1003-
assert str(exc_info.value) == "The `api_key` and `bearer_token_provider` arguments are mutually exclusive"
1004-
1005996

1006997
class TestAsyncOpenAI:
1007998
client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
@@ -1957,7 +1948,7 @@ async def test_refresh_auth_headers_token_async(self) -> None:
19571948
async def token_provider() -> str:
19581949
return "test_bearer_token"
19591950

1960-
client = AsyncOpenAI(base_url=base_url, bearer_token_provider=token_provider)
1951+
client = AsyncOpenAI(base_url=base_url, api_key=token_provider)
19611952
await client.refresh_auth_headers()
19621953
assert client.auth_headers.get("Authorization") == "Bearer test_bearer_token"
19631954

@@ -1989,7 +1980,7 @@ async def token_provider() -> str:
19891980

19901981
return "second"
19911982

1992-
client = AsyncOpenAI(base_url=base_url, bearer_token_provider=token_provider)
1983+
client = AsyncOpenAI(base_url=base_url, api_key=token_provider)
19931984
await client.chat.completions.create(messages=[], model="gpt-4")
19941985

19951986
calls = cast("list[MockRequestCall]", respx_mock.calls)
@@ -1998,14 +1989,6 @@ async def token_provider() -> str:
19981989
assert calls[0].request.headers.get("Authorization") == "Bearer first"
19991990
assert calls[1].request.headers.get("Authorization") == "Bearer second"
20001991

2001-
def test_auth_mutually_exclusive_async(self) -> None:
2002-
async def token_provider() -> str:
2003-
return "test_bearer_token"
2004-
2005-
with pytest.raises(ValueError) as exc_info:
2006-
AsyncOpenAI(base_url=base_url, api_key=api_key, bearer_token_provider=token_provider)
2007-
assert str(exc_info.value) == "The `api_key` and `bearer_token_provider` arguments are mutually exclusive"
2008-
20091992
@pytest.mark.asyncio
20101993
async def test_copy_auth(self) -> None:
20111994
async def token_provider_1() -> str:
@@ -2014,16 +1997,8 @@ async def token_provider_1() -> str:
20141997
async def token_provider_2() -> str:
20151998
return "test_bearer_token_2"
20161999

2017-
client = AsyncOpenAI(base_url=base_url, bearer_token_provider=token_provider_1).copy(
2018-
bearer_token_provider=token_provider_2
2000+
client = AsyncOpenAI(base_url=base_url, api_key=token_provider_1).copy(
2001+
api_key=token_provider_2
20192002
)
20202003
await client.refresh_auth_headers()
20212004
assert client.auth_headers == {"Authorization": "Bearer test_bearer_token_2"}
2022-
2023-
def test_copy_auth_mutually_exclusive_async(self) -> None:
2024-
async def token_provider() -> str:
2025-
return "test_bearer_token"
2026-
2027-
with pytest.raises(ValueError) as exc_info:
2028-
AsyncOpenAI(base_url=base_url, api_key=api_key).copy(bearer_token_provider=token_provider)
2029-
assert str(exc_info.value) == "The `api_key` and `bearer_token_provider` arguments are mutually exclusive"

tests/test_module_client.py

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -99,12 +99,8 @@ def test_http_client_option() -> None:
9999

100100

101101
def test_bearer_token_provider_option() -> None:
102-
assert openai.bearer_token_provider is None
103-
assert openai.completions._client.bearer_token_provider is None
102+
openai.api_key = lambda: "foo"
104103

105-
openai.bearer_token_provider = lambda: "foo"
106-
107-
assert openai.bearer_token_provider() == "foo"
108104
assert openai.completions._client.bearer_token_provider
109105
assert openai.completions._client.bearer_token_provider() == "foo"
110106

@@ -138,23 +134,19 @@ def test_only_api_key_results_in_openai_api() -> None:
138134
def test_only_bearer_token_provider_in_openai_api() -> None:
139135
with fresh_env():
140136
openai.api_type = None
141-
openai.api_key = None
142-
openai.bearer_token_provider = lambda: "example bearer token"
137+
openai.api_key = lambda: "example bearer token"
143138

144139
assert type(openai.completions._client).__name__ == "_ModuleClient"
145140

146141

147142
def test_both_api_key_and_bearer_token_provider_in_openai_api() -> None:
148143
with fresh_env():
149-
openai.api_key = "example API key"
150-
openai.bearer_token_provider = lambda: "example bearer token"
144+
openai.api_key = lambda: "example bearer token"
151145

152-
with pytest.raises(
153-
ValueError,
154-
match=r"The `api_key` and `bearer_token_provider` arguments are mutually exclusive",
155-
):
156-
openai.completions._client # noqa: B018
146+
assert(openai.api_key() == "example bearer token")
157147

148+
openai.api_key = "example API key"
149+
assert(openai.api_key == "example API key")
158150

159151
def test_azure_api_key_env_without_api_version() -> None:
160152
with fresh_env():

0 commit comments

Comments
 (0)