Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 16 additions & 18 deletions examples/azure_ad.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,24 @@
import asyncio

from openai.lib.azure import AzureOpenAI, AsyncAzureOpenAI, AzureADTokenProvider, AsyncAzureADTokenProvider

scopes = "https://cognitiveservices.azure.com/.default"

# May change in the future
# https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#rest-api-versioning
api_version = "2023-07-01-preview"
from openai.lib.azure import OpenAI, AsyncOpenAI, AzureAuth, AsyncAzureAuth, AzureADTokenProvider, AsyncAzureADTokenProvider

# https://learn.microsoft.com/en-us/azure/cognitive-services/openai/how-to/create-resource?pivots=web-portal#create-a-resource
endpoint = "https://my-resource.openai.azure.com"

deployment_name = "deployment-name" # e.g. gpt-35-instant
deployment_name = "deployment-name" # e.g. gpt-35-instant


def sync_main() -> None:
from azure.identity import DefaultAzureCredential, get_bearer_token_provider

token_provider: AzureADTokenProvider = get_bearer_token_provider(DefaultAzureCredential(), scopes)
token_provider: AzureADTokenProvider = get_bearer_token_provider(DefaultAzureCredential(), AzureAuth.DEFAULT_SCOPE)

client = AzureOpenAI(
api_version=api_version,
azure_endpoint=endpoint,
azure_ad_token_provider=token_provider,
client = OpenAI(
base_url=endpoint,
auth_provider=AzureAuth(token_provider),
default_query={ # Temporary requirement to specify api version - will be removed once v1 routes go GA
'api-version': 'preview'
}
)

completion = client.chat.completions.create(
Expand All @@ -41,12 +37,14 @@ def sync_main() -> None:
async def async_main() -> None:
from azure.identity.aio import DefaultAzureCredential, get_bearer_token_provider

token_provider: AsyncAzureADTokenProvider = get_bearer_token_provider(DefaultAzureCredential(), scopes)
token_provider: AsyncAzureADTokenProvider = get_bearer_token_provider(DefaultAzureCredential(), AsyncAzureAuth.DEFAULT_SCOPE)

client = AsyncAzureOpenAI(
api_version=api_version,
azure_endpoint=endpoint,
azure_ad_token_provider=token_provider,
client = AsyncOpenAI(
base_url=endpoint,
auth_provider=AsyncAzureAuth(token_provider),
default_query={ # Temporary requirement to specify api version - will be removed once v1 routes go GA
'api-version': 'preview'
}
)

completion = await client.chat.completions.create(
Expand Down
32 changes: 30 additions & 2 deletions src/openai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,19 @@
from . import types
from ._types import NOT_GIVEN, Omit, NoneType, NotGiven, Transport, ProxiesTypes
from ._utils import file_from_path
from ._client import Client, OpenAI, Stream, Timeout, Transport, AsyncClient, AsyncOpenAI, AsyncStream, RequestOptions
from ._client import (
Client,
OpenAI,
Stream,
Timeout,
Transport,
AsyncClient,
AsyncOpenAI,
AsyncStream,
AuthProvider,
AsyncAuthProvider,
RequestOptions,
)
from ._models import BaseModel
from ._version import __title__, __version__
from ._response import APIResponse as APIResponse, AsyncAPIResponse as AsyncAPIResponse
Expand Down Expand Up @@ -72,6 +84,8 @@
"AsyncStream",
"OpenAI",
"AsyncOpenAI",
"AuthProvider",
"AsyncAuthProvider",
"file_from_path",
"BaseModel",
"DEFAULT_TIMEOUT",
Expand All @@ -87,7 +101,7 @@

from .lib import azure as _azure, pydantic_function_tool as pydantic_function_tool
from .version import VERSION as VERSION
from .lib.azure import AzureOpenAI as AzureOpenAI, AsyncAzureOpenAI as AsyncAzureOpenAI
from .lib.azure import AzureOpenAI as AzureOpenAI, AsyncAzureOpenAI as AsyncAzureOpenAI, AzureAuth as AzureAuth, AsyncAzureAuth as AsyncAzureAuth
from .lib._old_api import *
from .lib.streaming import (
AssistantEventHandler as AssistantEventHandler,
Expand Down Expand Up @@ -119,6 +133,8 @@

api_key: str | None = None

auth_provider: AuthProvider | None = None

organization: str | None = None

project: str | None = None
Expand Down Expand Up @@ -165,6 +181,17 @@ def api_key(self, value: str | None) -> None: # type: ignore

api_key = value

@property # type: ignore
@override
def auth_provider(self) -> AuthProvider | None:
return auth_provider

@auth_provider.setter # type: ignore
def auth_provider(self, value: AuthProvider | None) -> None: # type: ignore
global auth_provider

auth_provider = value

@property # type: ignore
@override
def organization(self) -> str | None:
Expand Down Expand Up @@ -348,6 +375,7 @@ def _load_client() -> OpenAI: # type: ignore[reportUnusedFunction]

_client = _ModuleClient(
api_key=api_key,
auth_provider=auth_provider,
organization=organization,
project=project,
webhook_secret=webhook_secret,
Expand Down
84 changes: 76 additions & 8 deletions src/openai/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from __future__ import annotations

import os
from typing import TYPE_CHECKING, Any, Union, Mapping
from typing import TYPE_CHECKING, Any, Union, Mapping, Protocol
from typing_extensions import Self, override

import httpx
Expand All @@ -14,7 +14,9 @@
NOT_GIVEN,
Omit,
Timeout,
Headers,
NotGiven,
NotGivenOr,
Transport,
ProxiesTypes,
RequestOptions,
Expand All @@ -25,6 +27,7 @@
get_async_library,
)
from ._compat import cached_property
from ._models import FinalRequestOptions
from ._version import __version__
from ._streaming import Stream as Stream, AsyncStream as AsyncStream
from ._exceptions import OpenAIError, APIStatusError
Expand Down Expand Up @@ -73,6 +76,24 @@

__all__ = ["Timeout", "Transport", "ProxiesTypes", "RequestOptions", "OpenAI", "AsyncOpenAI", "Client", "AsyncClient"]

class AuthProvider(Protocol):

def do_auth(self, *, url: httpx.___URL, headers: NotGivenOr[Headers] = NOT_GIVEN, params: NotGivenOr[dict[str, object]] = NOT_GIVEN, cookies: Any = NOT_GIVEN, response: httpx.Response | None = None) -> tuple[httpx.___URL, NotGivenOr[Headers], NotGivenOr[dict[str, object]], Any]:
"""Perform authentication for the request.

This method should be overridden by subclasses to implement specific authentication logic.
"""
raise NotImplementedError("Subclasses must implement this method.")

class AsyncAuthProvider(Protocol):

async def do_auth(self, *, url: httpx.___URL, headers: NotGivenOr[Headers] = NOT_GIVEN, params: NotGivenOr[dict[str, object]] = NOT_GIVEN, cookies: Any = NOT_GIVEN, response: httpx.Response | None = None) -> tuple[httpx.___URL, NotGivenOr[Headers], NotGivenOr[dict[str, object]], Any]:
"""Perform authentication for the request.

This method should be overridden by subclasses to implement specific authentication logic.
"""
raise NotImplementedError("Subclasses must implement this method.")


class OpenAI(SyncAPIClient):
# client options
Expand All @@ -93,6 +114,7 @@ def __init__(
self,
*,
api_key: str | None = None,
auth_provider: AuthProvider | None = None,
organization: str | None = None,
project: str | None = None,
webhook_secret: str | None = None,
Expand Down Expand Up @@ -124,13 +146,16 @@ def __init__(
- `project` from `OPENAI_PROJECT_ID`
- `webhook_secret` from `OPENAI_WEBHOOK_SECRET`
"""
if api_key and auth_provider:
raise ValueError("The `api_key` and `auth_provider` arguments are mutually exclusive")
if api_key is None:
api_key = os.environ.get("OPENAI_API_KEY")
if api_key is None:
if api_key is None and auth_provider is None:
raise OpenAIError(
"The api_key client option must be set either by passing api_key to the client or by setting the OPENAI_API_KEY environment variable"
"The api_key or auth_provider client option must be set either by passing api_key or auth_provider to the client or by setting the OPENAI_API_KEY environment variable"
)
self.api_key = api_key
self.auth_provider = auth_provider
self.api_key = api_key or ""

if organization is None:
organization = os.environ.get("OPENAI_ORG_ID")
Expand Down Expand Up @@ -279,6 +304,21 @@ def with_streaming_response(self) -> OpenAIWithStreamedResponse:
def qs(self) -> Querystring:
return Querystring(array_format="brackets")

@override
def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions:
options = super()._prepare_options(options)

if self.auth_provider:
url, headers, params, _ = self.auth_provider.do_auth(
url = options.url,
headers = options.headers,
params = options.params
)
options.url = url
options.headers = headers
options.params = params
return options

@property
@override
def auth_headers(self) -> dict[str, str]:
Expand All @@ -287,7 +327,7 @@ def auth_headers(self) -> dict[str, str]:
# if the api key is an empty string, encoding the header will fail
return {}
return {"Authorization": f"Bearer {api_key}"}

@property
@override
def default_headers(self) -> dict[str, str | Omit]:
Expand All @@ -303,6 +343,7 @@ def copy(
self,
*,
api_key: str | None = None,
auth_provider: AuthProvider | None = None,
organization: str | None = None,
project: str | None = None,
webhook_secret: str | None = None,
Expand Down Expand Up @@ -338,6 +379,10 @@ def copy(
elif set_default_query is not None:
params = set_default_query

auth_provider = auth_provider or self.auth_provider
if auth_provider is not None:
_extra_kwargs = {**_extra_kwargs, "auth_provider": auth_provider}

http_client = http_client or self._client
return self.__class__(
api_key=api_key or self.api_key,
Expand Down Expand Up @@ -412,6 +457,7 @@ def __init__(
self,
*,
api_key: str | None = None,
auth_provider: AsyncAuthProvider | None = None,
organization: str | None = None,
project: str | None = None,
webhook_secret: str | None = None,
Expand Down Expand Up @@ -443,13 +489,16 @@ def __init__(
- `project` from `OPENAI_PROJECT_ID`
- `webhook_secret` from `OPENAI_WEBHOOK_SECRET`
"""
if api_key and auth_provider:
raise ValueError("The `api_key` and `auth_provider` arguments are mutually exclusive")
if api_key is None:
api_key = os.environ.get("OPENAI_API_KEY")
if api_key is None:
if api_key is None and auth_provider is None:
raise OpenAIError(
"The api_key client option must be set either by passing api_key to the client or by setting the OPENAI_API_KEY environment variable"
"The api_key or auth_provider client option must be set either by passing api_key or auth_provider to the client or by setting the OPENAI_API_KEY environment variable"
)
self.api_key = api_key
self.auth_provider = auth_provider
self.api_key = api_key or ""

if organization is None:
organization = os.environ.get("OPENAI_ORG_ID")
Expand Down Expand Up @@ -598,6 +647,20 @@ def with_streaming_response(self) -> AsyncOpenAIWithStreamedResponse:
def qs(self) -> Querystring:
return Querystring(array_format="brackets")

@override
async def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions:
options = await super()._prepare_options(options)
if self.auth_provider:
url, headers, params, _ = await self.auth_provider.do_auth(
url = options.url,
headers = options.headers,
params = options.params
)
options.url = url
options.headers = headers
options.params = params
return options

@property
@override
def auth_headers(self) -> dict[str, str]:
Expand All @@ -622,6 +685,7 @@ def copy(
self,
*,
api_key: str | None = None,
auth_provider: AsyncAuthProvider | None = None,
organization: str | None = None,
project: str | None = None,
webhook_secret: str | None = None,
Expand Down Expand Up @@ -657,6 +721,10 @@ def copy(
elif set_default_query is not None:
params = set_default_query

auth_provider = auth_provider or self.auth_provider
if auth_provider is not None:
_extra_kwargs = {**_extra_kwargs, "auth_provider": auth_provider}

http_client = http_client or self._client
return self.__class__(
api_key=api_key or self.api_key,
Expand Down
Loading