Skip to content
21 changes: 16 additions & 5 deletions src/openai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@

from ._base_client import DEFAULT_TIMEOUT, DEFAULT_MAX_RETRIES

api_key: str | None = None
api_key: str | _t.Callable[[], str] | None = None

organization: str | None = None

Expand Down Expand Up @@ -156,16 +156,27 @@ class _ModuleClient(OpenAI):

@property # type: ignore
@override
def api_key(self) -> str | None:
return api_key
def api_key(self) -> str | _t.Callable[[], str] | None:
return api_key() if callable(api_key) else api_key

@api_key.setter # type: ignore
def api_key(self, value: str | None) -> None: # type: ignore
def api_key(self, value: str | _t.Callable[[], str] | None) -> None: # type: ignore
global api_key
api_key = value

@property
def _api_key_provider(self) -> _t.Callable[[], str] | None: # type: ignore
return None

@_api_key_provider.setter
def _api_key_provider(self, value: _t.Callable[[], str] | None) -> None: # type: ignore
global api_key
# Yes, setting the api_key is intentional. The module level client accepts callables
# for the module level api_key and will call it to retrieve the value
# if it is a callable.
api_key = value

@property # type: ignore
@property
@override
def organization(self) -> str | None:
return organization
Expand Down
48 changes: 39 additions & 9 deletions src/openai/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
from __future__ import annotations

import os
from typing import TYPE_CHECKING, Any, Union, Mapping
from typing import TYPE_CHECKING, Any, Union, Mapping, Callable, Awaitable
from typing_extensions import Self, override

import httpx

from openai._models import FinalRequestOptions

from . import _exceptions
from ._qs import Querystring
from ._types import (
Expand Down Expand Up @@ -94,7 +96,7 @@ class OpenAI(SyncAPIClient):
def __init__(
self,
*,
api_key: str | None = None,
api_key: str | None | Callable[[], str] = None,
organization: str | None = None,
project: str | None = None,
webhook_secret: str | None = None,
Expand Down Expand Up @@ -132,7 +134,12 @@ def __init__(
raise OpenAIError(
"The api_key client option must be set either by passing api_key to the client or by setting the OPENAI_API_KEY environment variable"
)
self.api_key = api_key
if callable(api_key):
self.api_key = ""
self._api_key_provider: Callable[[], str] | None = api_key
else:
self.api_key = api_key or ""
self._api_key_provider = None

if organization is None:
organization = os.environ.get("OPENAI_ORG_ID")
Expand Down Expand Up @@ -287,6 +294,15 @@ def with_streaming_response(self) -> OpenAIWithStreamedResponse:
def qs(self) -> Querystring:
return Querystring(array_format="brackets")

def _refresh_api_key(self) -> None:
if self._api_key_provider:
self.api_key = self._api_key_provider()

@override
def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions:
self._refresh_api_key()
return super()._prepare_options(options)

@property
@override
def auth_headers(self) -> dict[str, str]:
Expand All @@ -310,7 +326,7 @@ def default_headers(self) -> dict[str, str | Omit]:
def copy(
self,
*,
api_key: str | None = None,
api_key: str | Callable[[], str] | None = None,
organization: str | None = None,
project: str | None = None,
webhook_secret: str | None = None,
Expand Down Expand Up @@ -348,7 +364,7 @@ def copy(

http_client = http_client or self._client
return self.__class__(
api_key=api_key or self.api_key,
api_key=api_key or self._api_key_provider or self.api_key,
organization=organization or self.organization,
project=project or self.project,
webhook_secret=webhook_secret or self.webhook_secret,
Expand Down Expand Up @@ -419,7 +435,7 @@ class AsyncOpenAI(AsyncAPIClient):
def __init__(
self,
*,
api_key: str | None = None,
api_key: str | Callable[[], Awaitable[str]] | None = None,
organization: str | None = None,
project: str | None = None,
webhook_secret: str | None = None,
Expand Down Expand Up @@ -457,7 +473,12 @@ def __init__(
raise OpenAIError(
"The api_key client option must be set either by passing api_key to the client or by setting the OPENAI_API_KEY environment variable"
)
self.api_key = api_key
if callable(api_key):
self.api_key = ""
self._api_key_provider: Callable[[], Awaitable[str]] | None = api_key
else:
self.api_key = api_key or ""
self._api_key_provider = None

if organization is None:
organization = os.environ.get("OPENAI_ORG_ID")
Expand Down Expand Up @@ -612,6 +633,15 @@ def with_streaming_response(self) -> AsyncOpenAIWithStreamedResponse:
def qs(self) -> Querystring:
return Querystring(array_format="brackets")

async def _refresh_api_key(self) -> None:
if self._api_key_provider:
self.api_key = await self._api_key_provider()

@override
async def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions:
await self._refresh_api_key()
return await super()._prepare_options(options)

@property
@override
def auth_headers(self) -> dict[str, str]:
Expand All @@ -635,7 +665,7 @@ def default_headers(self) -> dict[str, str | Omit]:
def copy(
self,
*,
api_key: str | None = None,
api_key: str | Callable[[], Awaitable[str]] | None = None,
organization: str | None = None,
project: str | None = None,
webhook_secret: str | None = None,
Expand Down Expand Up @@ -673,7 +703,7 @@ def copy(

http_client = http_client or self._client
return self.__class__(
api_key=api_key or self.api_key,
api_key=api_key or self._api_key_provider or self.api_key,
organization=organization or self.organization,
project=project or self.project,
webhook_secret=webhook_secret or self.webhook_secret,
Expand Down
8 changes: 4 additions & 4 deletions src/openai/lib/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ def __init__(
def copy(
self,
*,
api_key: str | None = None,
api_key: str | Callable[[], str] | None = None,
organization: str | None = None,
project: str | None = None,
webhook_secret: str | None = None,
Expand Down Expand Up @@ -435,7 +435,7 @@ def __init__(
azure_endpoint: str | None = None,
azure_deployment: str | None = None,
api_version: str | None = None,
api_key: str | None = None,
api_key: str | Callable[[], Awaitable[str]] | None = None,
azure_ad_token: str | None = None,
azure_ad_token_provider: AsyncAzureADTokenProvider | None = None,
organization: str | None = None,
Expand Down Expand Up @@ -539,7 +539,7 @@ def __init__(
def copy(
self,
*,
api_key: str | None = None,
api_key: str | Callable[[], Awaitable[str]] | None = None,
organization: str | None = None,
project: str | None = None,
webhook_secret: str | None = None,
Expand Down Expand Up @@ -628,7 +628,7 @@ async def _configure_realtime(self, model: str, extra_query: Query) -> tuple[htt
"api-version": self._api_version,
"deployment": self._azure_deployment or model,
}
if self.api_key != "<missing API key>":
if self.api_key and self.api_key != "<missing API key>":
auth_headers = {"api-key": self.api_key}
else:
token = await self._get_azure_ad_token()
Expand Down
2 changes: 2 additions & 0 deletions src/openai/resources/beta/realtime/realtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,7 @@ async def __aenter__(self) -> AsyncRealtimeConnection:
raise OpenAIError("You need to install `openai[realtime]` to use this method") from exc

extra_query = self.__extra_query
await self.__client._refresh_api_key()
auth_headers = self.__client.auth_headers
if is_async_azure_client(self.__client):
url, auth_headers = await self.__client._configure_realtime(self.__model, extra_query)
Expand Down Expand Up @@ -540,6 +541,7 @@ def __enter__(self) -> RealtimeConnection:
raise OpenAIError("You need to install `openai[realtime]` to use this method") from exc

extra_query = self.__extra_query
self.__client._refresh_api_key()
auth_headers = self.__client.auth_headers
if is_azure_client(self.__client):
url, auth_headers = self.__client._configure_realtime(self.__model, extra_query)
Expand Down
Loading