-
Notifications
You must be signed in to change notification settings - Fork 4.2k
Add support for token refresh by accepting a callable api_key parameter value #2588
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
…low bearer tokens to be updated Allow api_key to be a callable to enable refresh of keys/tokens.
… Propagate bearer_token_provider in the `copy` method.
* add tests, fix copy, add token provider to module client * fix lint * ignore for azure copy * revert change
@@ -255,7 +255,7 @@ def __init__( | |||
self._azure_endpoint = httpx.URL(azure_endpoint) if azure_endpoint else None | |||
|
|||
@override | |||
def copy( | |||
def copy( # type: ignore |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should we not update the api_key
type here as well so that we don't need the type: ignore?
@@ -79,6 +81,7 @@ | |||
class OpenAI(SyncAPIClient): | |||
# client options | |||
api_key: str | |||
bearer_token_provider: Callable[[], str] | None = None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think this should be public, instead something like?
bearer_token_provider: Callable[[], str] | None = None | |
_api_key_provider: Callable[[], str] | None = None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Making it public allows the same users that read the api_key attribute to get an updated bearer token. But I also don't think that (beyond debugging) there are that many scenarios where you would want to do that. So I have no problems with making it private.
If I do follow that train of thought (developers are unlikely to read the api_key off the client in "real" production code) all the way through, I could make the argument that while it would technically be breaking to make the api_key property on the client a str | Callable[[], str]
, it would simplify things by having a single place to store what is currently mutually exclusive settings.
@@ -287,14 +296,28 @@ def with_streaming_response(self) -> OpenAIWithStreamedResponse: | |||
def qs(self) -> Querystring: | |||
return Querystring(array_format="brackets") | |||
|
|||
def refresh_auth_headers(self) -> None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this should be a private method, I don't think we should expose this publicly.
@@ -287,14 +296,28 @@ def with_streaming_response(self) -> OpenAIWithStreamedResponse: | |||
def qs(self) -> Querystring: | |||
return Querystring(array_format="brackets") | |||
|
|||
def refresh_auth_headers(self) -> None: | |||
if self.bearer_token_provider: | |||
secret = self.bearer_token_provider() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do you have any fud around just setting self.api_key
directly and keeping def auth_headers()
mostly as-is?
this code feels prone to weird behaviour if a user ever mutates client.api_key
as we may use the version that's cached from self._auth_headers
instead. based on the current implementation that should basically be impossible to actually cause a bug as far as I can tell but I'd prefer to completely avoid that possibility.
def test_refresh_auth_headers_token(self) -> None: | ||
client = OpenAI(base_url=base_url, api_key=lambda: "test_bearer_token") | ||
client.refresh_auth_headers() | ||
assert client.auth_headers.get("Authorization") == "Bearer test_bearer_token" | ||
|
||
def test_refresh_auth_headers_key(self) -> None: | ||
client = OpenAI(base_url=base_url, api_key="test_api_key") | ||
client.refresh_auth_headers() | ||
assert client.auth_headers.get("Authorization") == "Bearer test_api_key" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: I don't think these tests are super valuable as we're testing the runtime behaviour below.
a more useful test here would instead be what happens if you access client.auth_headers
before a request is made
Co-authored-by: Robert Craigie <[email protected]>
Changes being requested
Make the
api_key
parameter accept atyping. Callable[[], str]
(ortyping.Callable[[], typing.Awaitable[str]]
for the async client) to allow for dynamic token refresh.Additional context & links
This is what the python_ad.py example would look like when using the base OpenAI clients + the bearer token provider.