@@ -3,7 +3,7 @@
 import datetime
 from os import environ
 import re
-from typing import Annotated, Awaitable, Generic, TypeVar, TypedDict
+from typing import Annotated, Awaitable, Generic, TypeVar
 from traceback import print_exc
 
 from pydantic import BaseModel
@@ -15,63 +15,9 @@
 from fastapi import FastAPI, HTTPException, Path, Query
 from aiohttp import ClientSession
 
-KeyT = TypeVar("KeyT")
-ValueT = TypeVar("ValueT")
-
-# LRUCache that has sections and is accessed asynchronously
-class LRUCache(Generic[KeyT, ValueT]):
-    def __init__(self, max_size: int = 1024):
-        self.max_size = max_size
-        self._cache = {}
-        self._lru = []
-        self._lock = asyncio.Lock()
-
-    async def get(self, key: KeyT) -> ValueT | None:
-        async with self._lock:
-            if key in self._cache:
-                self._lru.remove(key)
-                self._lru.append(key)
-                return self._cache[key]
-            else:
-                return None
-
-    async def set(self, key: KeyT, value: ValueT) -> None:
-        async with self._lock:
-            if key in self._cache:
-                self._lru.remove(key)
-            elif len(self._cache) >= self.max_size:
-                del self._cache[self._lru.pop(0)]
-            self._cache[key] = value
-            self._lru.append(key)
-
-    async def delete(self, key: KeyT) -> None:
-        async with self._lock:
-            if key in self._cache:
-                del self._cache[key]
-                self._lru.remove(key)
-
-    async def prune(self, keep: int | None = 128) -> None:
-        async with self._lock:
-            if keep is None:
-                self._cache = {}
-                self._lru = []
-            else:
-                self._lru = self._lru[-keep:]
-                self._cache = {key: self._cache[key] for key in self._lru}
-
-class LRUCacheConsumers(TypedDict):
-    batch_job_count: LRUCache[tuple[int, datetime.datetime | None], int]  # Key is (batch_id, after)
-    job_count: LRUCache[datetime.datetime | None, int]  # Key is after
-    batch_count: LRUCache[datetime.datetime | None, int]  # Key is after
-    url_count: LRUCache[datetime.datetime | None, int]  # Key is after
-    repeat_url_count: LRUCache[datetime.datetime | None, int]  # Key is after
-
-
 engine: sqlalchemy.ext.asyncio.AsyncEngine = None
 async_session: sqlalchemy.ext.asyncio.async_sessionmaker[sqlalchemy.ext.asyncio.AsyncSession] = None
 client_session: ClientSession = None
-cache: LRUCacheConsumers = None
-
 class Base(sqlalchemy.orm.MappedAsDataclass, sqlalchemy.ext.asyncio.AsyncAttrs, sqlalchemy.orm.DeclarativeBase):
     pass
 
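Note on the deleted class: it tracks recency with a plain list, so every hit pays an O(n) `list.remove`. A minimal sketch of the same async get/set semantics built on `collections.OrderedDict`, whose `move_to_end` and `popitem` are O(1); this is a hypothetical alternative, not part of this commit:

```python
import asyncio
from collections import OrderedDict
from typing import Generic, TypeVar

KeyT = TypeVar("KeyT")
ValueT = TypeVar("ValueT")

class OrderedDictLRUCache(Generic[KeyT, ValueT]):
    """Same get/set behavior as the deleted LRUCache, with O(1) recency updates."""

    def __init__(self, max_size: int = 1024):
        self.max_size = max_size
        self._cache: OrderedDict[KeyT, ValueT] = OrderedDict()
        self._lock = asyncio.Lock()

    async def get(self, key: KeyT) -> ValueT | None:
        async with self._lock:
            if key in self._cache:
                self._cache.move_to_end(key)  # O(1) "most recently used" bump
                return self._cache[key]
            return None

    async def set(self, key: KeyT, value: ValueT) -> None:
        async with self._lock:
            if key in self._cache:
                self._cache.move_to_end(key)
            elif len(self._cache) >= self.max_size:
                self._cache.popitem(last=False)  # O(1) eviction of the LRU entry
            self._cache[key] = value
```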
@@ -226,7 +172,6 @@ async def lifespan(_: FastAPI):
         await conn.run_sync(Base.metadata.create_all)
     workers.append(asyncio.create_task(exception_logger(url_worker(), name="url_worker")))
     workers.append(asyncio.create_task(exception_logger(repeat_url_worker(), name="repeat_url_worker")))
-    cache = {key: LRUCache() for key in LRUCacheConsumers.__required_keys__}
     try:
         yield
     finally:
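The deleted startup line relies on TypedDict introspection: `__required_keys__` is a frozenset of the required field names declared on the class. A standalone illustration, using a hypothetical cut-down stand-in for `LRUCacheConsumers`:

```python
from typing import TypedDict

# Hypothetical, cut-down stand-in for LRUCacheConsumers.
class Consumers(TypedDict):
    job_count: int
    batch_count: int

# __required_keys__ lists every required field declared on the TypedDict.
print(Consumers.__required_keys__)  # frozenset({'job_count', 'batch_count'})

# The deleted line used the same frozenset to build one cache per section.
caches = {key: {} for key in Consumers.__required_keys__}
```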
@@ -487,19 +432,15 @@ class PaginationOutput(BaseModel, Generic[ModelT]):
 
 @app.get("/jobs")
 async def get_jobs(page: Page = 1, after: datetime.datetime | None = None, desc: bool = False) -> PaginationOutput[JobReturn]:
-    job_count_cache = cache["job_count"]
-    cache_key = after
     async with async_session() as session, session.begin():
-        if (job_count := await job_count_cache.get(cache_key)) is None:
-            stmt = select(sqlalchemy.func.count(Job.id))
-            if after:
-                stmt = stmt.where(Job.created_at > after)
-            job_count = await session.scalar(stmt)
-            await job_count_cache.set(cache_key, job_count)
-        stmt = select(Job).order_by(Job.id.desc() if desc else Job.id.asc()).offset((page - 1) * 100).limit(100).options(sqlalchemy.orm.joinedload(Job.batches))
+        stmt = select(sqlalchemy.func.count(Job.id))
         if after:
             stmt = stmt.where(Job.created_at > after)
-        result = await session.scalars(stmt)
+        job_count = await session.scalar(stmt)
+        stmt2 = select(Job).order_by(Job.id.desc() if desc else Job.id.asc()).offset((page - 1) * 100).limit(100).options(sqlalchemy.orm.joinedload(Job.batches))
+        if after:
+            stmt2 = stmt2.where(Job.created_at > after)
+        result = await session.scalars(stmt2)
     return PaginationOutput(
         data=[JobReturn.from_job(job) for job in result.unique().all()],
         pagination=PaginationInfo(
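The replacement code applies the `after` filter twice, once per statement. A possible follow-up refactor, not in this commit and reusing `session`, `Job`, `after`, `desc`, and `page` from the handler above, would build the filter once and derive both queries from it:

```python
# Shared base select carrying the optional created_at filter.
base = select(Job)
if after:
    base = base.where(Job.created_at > after)

# Count the filtered rows via a subquery of the same base statement.
job_count = await session.scalar(
    select(sqlalchemy.func.count()).select_from(base.subquery())
)

# Page query: ordering, offset/limit, and eager loading layered on the same base.
result = await session.scalars(
    base.order_by(Job.id.desc() if desc else Job.id.asc())
    .offset((page - 1) * 100)
    .limit(100)
    .options(sqlalchemy.orm.joinedload(Job.batches))
)
```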
@@ -511,19 +452,15 @@ async def get_jobs(page: Page = 1, after: datetime.datetime | None = None, desc:
 
 @app.get("/batch/{batch_id}/jobs")
 async def get_batch_jobs(batch_id: Annotated[int, Path(title="Batch ID", description="The ID of the batch you want info on", ge=1)], page: Page = 1, after: datetime.datetime | None = None, desc: bool = False) -> PaginationOutput[JobReturn]:
-    job_count_cache = cache["batch_job_count"]
-    cache_key = (batch_id, after)
     async with async_session() as session, session.begin():
-        if (job_count := await job_count_cache.get(cache_key)) is None:
-            stmt = select(sqlalchemy.func.count(Job.id)).join(Batch.jobs).where(Batch.id == batch_id)
-            if after:
-                stmt = stmt.where(Job.created_at > after)
-            job_count = await session.scalar(stmt)
-            await job_count_cache.set(cache_key, job_count)
-        stmt = select(Job).join(Batch.jobs).where(Batch.id == batch_id).order_by(Job.id.desc() if desc else Job.id.asc()).offset((page - 1) * 100).limit(100).options(sqlalchemy.orm.joinedload(Job.batches))
+        stmt = select(sqlalchemy.func.count(Job.id)).join(Batch.jobs).where(Batch.id == batch_id)
         if after:
             stmt = stmt.where(Job.created_at > after)
-        result = await session.scalars(stmt)
+        job_count = await session.scalar(stmt)
+        stmt2 = select(Job).join(Batch.jobs).where(Batch.id == batch_id).order_by(Job.id.desc() if desc else Job.id.asc()).offset((page - 1) * 100).limit(100).options(sqlalchemy.orm.joinedload(Job.batches))
+        if after:
+            stmt2 = stmt2.where(Job.created_at > after)
+        result = await session.scalars(stmt2)
     return PaginationOutput(
         data=[JobReturn.from_job(job) for job in result.unique().all()],
         pagination=PaginationInfo(
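These handlers pair `joinedload` with `result.unique()`, which SQLAlchemy requires when a collection is eager-loaded via JOIN, because parent rows come back duplicated. A hypothetical alternative, not in this commit, is `selectinload`, which fetches the collection with a second IN-based SELECT and needs no uniquing:

```python
# Same page query as get_batch_jobs above (names reused from that handler),
# but eager-loading Job.batches with selectinload instead of joinedload.
stmt2 = (
    select(Job)
    .join(Batch.jobs)
    .where(Batch.id == batch_id)
    .order_by(Job.id.desc() if desc else Job.id.asc())
    .offset((page - 1) * 100)
    .limit(100)
    .options(sqlalchemy.orm.selectinload(Job.batches))
)
result = await session.scalars(stmt2)
jobs = result.all()  # no .unique() needed without a joined collection
```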
@@ -535,19 +472,15 @@ async def get_batch_jobs(batch_id: Annotated[int, Path(title="Batch ID", descrip
 
 @app.get("/batches")
 async def get_batches(page: Page = 1, after: datetime.datetime | None = None, desc: bool = False) -> PaginationOutput[BatchReturn]:
-    batch_count_cache = cache["batch_count"]
-    cache_key = after
     async with async_session() as session, session.begin():
-        if (batch_count := await batch_count_cache.get(cache_key)) is None:
-            stmt = select(sqlalchemy.func.count(Batch.id))
-            if after:
-                stmt = stmt.where(Batch.created_at > after)
-            batch_count = await session.scalar(stmt)
-            await batch_count_cache.set(cache_key, batch_count)
-        stmt = select(Batch).order_by(Batch.id.desc() if desc else Batch.id.asc()).offset((page - 1) * 100).limit(100).options(sqlalchemy.orm.joinedload(Batch.jobs))
+        stmt = select(sqlalchemy.func.count(Batch.id))
         if after:
             stmt = stmt.where(Batch.created_at > after)
-        result = await session.scalars(stmt)
+        batch_count = await session.scalar(stmt)
+        stmt2 = select(Batch).order_by(Batch.id.desc() if desc else Batch.id.asc()).offset((page - 1) * 100).limit(100).options(sqlalchemy.orm.joinedload(Batch.jobs))
+        if after:
+            stmt2 = stmt2.where(Batch.created_at > after)
+        result = await session.scalars(stmt2)
     return PaginationOutput(
         data=[BatchReturn(id=batch.id, created_at=batch.created_at, jobs=len(batch.jobs)) for batch in result.unique().all()],
         pagination=PaginationInfo(