|
3 | 3 | from __future__ import annotations
|
4 | 4 |
|
5 | 5 | import os
|
6 |
| -from typing import TYPE_CHECKING, Any, Union, Mapping |
| 6 | +from typing import TYPE_CHECKING, Any, Union, Mapping, Callable, Awaitable |
7 | 7 | from typing_extensions import Self, override
|
8 | 8 |
|
9 | 9 | import httpx
|
10 | 10 |
|
| 11 | +from openai._models import FinalRequestOptions |
| 12 | + |
11 | 13 | from . import _exceptions
|
12 | 14 | from ._qs import Querystring
|
13 | 15 | from ._types import (
|
@@ -95,6 +97,7 @@ def __init__(
|
95 | 97 | self,
|
96 | 98 | *,
|
97 | 99 | api_key: str | None = None,
|
| 100 | + bearer_token_provider: Callable[[], str] | None = None, |
98 | 101 | organization: str | None = None,
|
99 | 102 | project: str | None = None,
|
100 | 103 | webhook_secret: str | None = None,
|
@@ -128,11 +131,12 @@ def __init__(
|
128 | 131 | """
|
129 | 132 | if api_key is None:
|
130 | 133 | api_key = os.environ.get("OPENAI_API_KEY")
|
131 |
| - if api_key is None: |
| 134 | + if api_key is None and bearer_token_provider is None: |
132 | 135 | raise OpenAIError(
|
133 | 136 | "The api_key client option must be set either by passing api_key to the client or by setting the OPENAI_API_KEY environment variable"
|
134 | 137 | )
|
135 |
| - self.api_key = api_key |
| 138 | + self.bearer_token_provider = bearer_token_provider |
| 139 | + self.api_key = api_key or '' |
136 | 140 |
|
137 | 141 | if organization is None:
|
138 | 142 | organization = os.environ.get("OPENAI_ORG_ID")
|
@@ -165,6 +169,7 @@ def __init__(
|
165 | 169 | )
|
166 | 170 |
|
167 | 171 | self._default_stream_cls = Stream
|
| 172 | + self._auth_headers: dict[str, str] = {} |
168 | 173 |
|
169 | 174 | @cached_property
|
170 | 175 | def completions(self) -> Completions:
|
@@ -281,21 +286,26 @@ def with_raw_response(self) -> OpenAIWithRawResponse:
|
281 | 286 | @cached_property
|
282 | 287 | def with_streaming_response(self) -> OpenAIWithStreamedResponse:
|
283 | 288 | return OpenAIWithStreamedResponse(self)
|
284 |
| - |
285 | 289 | @property
|
286 | 290 | @override
|
287 | 291 | def qs(self) -> Querystring:
|
288 | 292 | return Querystring(array_format="brackets")
|
289 | 293 |
|
| 294 | + def refresh_auth_headers(self): |
| 295 | + bearer_token = self.bearer_token_provider() if self.bearer_token_provider else self.api_key |
| 296 | + self._auth_headers = {"Authorization": f"Bearer {bearer_token}"} |
| 297 | + |
| 298 | + |
| 299 | + @override |
| 300 | + def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions: |
| 301 | + self.refresh_auth_headers() |
| 302 | + return super()._prepare_options(options) |
| 303 | + |
290 | 304 | @property
|
291 | 305 | @override
|
292 | 306 | def auth_headers(self) -> dict[str, str]:
|
293 |
| - api_key = self.api_key |
294 |
| - if not api_key: |
295 |
| - # if the api key is an empty string, encoding the header will fail |
296 |
| - return {} |
297 |
| - return {"Authorization": f"Bearer {api_key}"} |
298 |
| - |
| 307 | + return self._auth_headers |
| 308 | + |
299 | 309 | @property
|
300 | 310 | @override
|
301 | 311 | def default_headers(self) -> dict[str, str | Omit]:
|
@@ -420,6 +430,7 @@ def __init__(
|
420 | 430 | self,
|
421 | 431 | *,
|
422 | 432 | api_key: str | None = None,
|
| 433 | + bearer_token_provider: Callable[[], Awaitable[str]] | None = None, |
423 | 434 | organization: str | None = None,
|
424 | 435 | project: str | None = None,
|
425 | 436 | webhook_secret: str | None = None,
|
@@ -453,11 +464,12 @@ def __init__(
|
453 | 464 | """
|
454 | 465 | if api_key is None:
|
455 | 466 | api_key = os.environ.get("OPENAI_API_KEY")
|
456 |
| - if api_key is None: |
| 467 | + if api_key is None and bearer_token_provider is None: |
457 | 468 | raise OpenAIError(
|
458 | 469 | "The api_key client option must be set either by passing api_key to the client or by setting the OPENAI_API_KEY environment variable"
|
459 | 470 | )
|
460 |
| - self.api_key = api_key |
| 471 | + self.bearer_token_provider = bearer_token_provider |
| 472 | + self.api_key = api_key or '' |
461 | 473 |
|
462 | 474 | if organization is None:
|
463 | 475 | organization = os.environ.get("OPENAI_ORG_ID")
|
@@ -490,6 +502,7 @@ def __init__(
|
490 | 502 | )
|
491 | 503 |
|
492 | 504 | self._default_stream_cls = AsyncStream
|
| 505 | + self._auth_headers: dict[str, str] = {} |
493 | 506 |
|
494 | 507 | @cached_property
|
495 | 508 | def completions(self) -> AsyncCompletions:
|
@@ -612,14 +625,22 @@ def with_streaming_response(self) -> AsyncOpenAIWithStreamedResponse:
|
612 | 625 | def qs(self) -> Querystring:
|
613 | 626 | return Querystring(array_format="brackets")
|
614 | 627 |
|
| 628 | + async def refresh_auth_headers(self): |
| 629 | + if self.bearer_token_provider: |
| 630 | + bearer_token = await self.bearer_token_provider() |
| 631 | + else: |
| 632 | + bearer_token = self.api_key |
| 633 | + self._auth_headers = {"Authorization": f"Bearer {bearer_token}"} |
| 634 | + |
| 635 | + @override |
| 636 | + async def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions: |
| 637 | + await self.refresh_auth_headers() |
| 638 | + return await super()._prepare_options(options) |
| 639 | + |
615 | 640 | @property
|
616 | 641 | @override
|
617 | 642 | def auth_headers(self) -> dict[str, str]:
|
618 |
| - api_key = self.api_key |
619 |
| - if not api_key: |
620 |
| - # if the api key is an empty string, encoding the header will fail |
621 |
| - return {} |
622 |
| - return {"Authorization": f"Bearer {api_key}"} |
| 643 | + return self._auth_headers |
623 | 644 |
|
624 | 645 | @property
|
625 | 646 | @override
|
|
0 commit comments