6
6
import threading
7
7
import time
8
8
import warnings
9
+ from contextlib import asynccontextmanager
9
10
from json import JSONDecodeError
10
11
from typing import (
11
- AsyncContextManager ,
12
12
AsyncGenerator ,
13
+ AsyncIterator ,
13
14
Callable ,
14
15
Dict ,
15
16
Iterator ,
@@ -367,9 +368,8 @@ async def arequest(
367
368
request_id : Optional [str ] = None ,
368
369
request_timeout : Optional [Union [float , Tuple [float , float ]]] = None ,
369
370
) -> Tuple [Union [OpenAIResponse , AsyncGenerator [OpenAIResponse , None ]], bool , str ]:
370
- ctx = AioHTTPSession ()
371
+ ctx = aiohttp_session ()
371
372
session = await ctx .__aenter__ ()
372
- result = None
373
373
try :
374
374
result = await self .arequest_raw (
375
375
method .lower (),
@@ -383,9 +383,6 @@ async def arequest(
383
383
)
384
384
resp , got_stream = await self ._interpret_async_response (result , stream )
385
385
except Exception :
386
- # Close the request before exiting session context.
387
- if result is not None :
388
- result .release ()
389
386
await ctx .__aexit__ (None , None , None )
390
387
raise
391
388
if got_stream :
@@ -396,15 +393,10 @@ async def wrap_resp():
396
393
async for r in resp :
397
394
yield r
398
395
finally :
399
- # Close the request before exiting session context. Important to do it here
400
- # as if stream is not fully exhausted, we need to close the request nevertheless.
401
- result .release ()
402
396
await ctx .__aexit__ (None , None , None )
403
397
404
398
return wrap_resp (), got_stream , self .api_key
405
399
else :
406
- # Close the request before exiting session context.
407
- result .release ()
408
400
await ctx .__aexit__ (None , None , None )
409
401
return resp , got_stream , self .api_key
410
402
@@ -778,22 +770,11 @@ def _interpret_response_line(
778
770
return resp
779
771
780
772
781
- class AioHTTPSession (AsyncContextManager ):
782
- def __init__ (self ):
783
- self ._session = None
784
- self ._should_close_session = False
785
-
786
- async def __aenter__ (self ):
787
- self ._session = openai .aiosession .get ()
788
- if self ._session is None :
789
- self ._session = await aiohttp .ClientSession ().__aenter__ ()
790
- self ._should_close_session = True
791
-
792
- return self ._session
793
-
794
- async def __aexit__ (self , exc_type , exc_value , traceback ):
795
- if self ._session is None :
796
- raise RuntimeError ("Session is not initialized" )
797
-
798
- if self ._should_close_session :
799
- await self ._session .__aexit__ (exc_type , exc_value , traceback )
773
+ @asynccontextmanager
774
+ async def aiohttp_session () -> AsyncIterator [aiohttp .ClientSession ]:
775
+ user_set_session = openai .aiosession .get ()
776
+ if user_set_session :
777
+ yield user_set_session
778
+ else :
779
+ async with aiohttp .ClientSession () as session :
780
+ yield session
0 commit comments