3
3
import datetime
4
4
from os import environ
5
5
import re
6
- from typing import Awaitable
6
+ from typing import Annotated , Awaitable , Generic , TypeVar , TypedDict
7
7
from traceback import print_exc
8
8
9
9
from pydantic import BaseModel
12
12
import sqlalchemy .ext .asyncio
13
13
from sqlalchemy import select , update
14
14
from sqlalchemy .orm import Mapped , mapped_column
15
- from fastapi import FastAPI , HTTPException
15
+ from fastapi import FastAPI , HTTPException , Path , Query
16
16
from aiohttp import ClientSession
17
17
18
+ KeyT = TypeVar ("KeyT" )
19
+ ValueT = TypeVar ("ValueT" )
20
+
21
+ # LRUCache that has sections is accessed asynchonously
22
+ class LRUCache (Generic [KeyT , ValueT ]):
23
+ def __init__ (self , max_size : int = 1024 ):
24
+ self .max_size = max_size
25
+ self ._cache = {}
26
+ self ._lru = []
27
+ self ._lock = asyncio .Lock ()
28
+
29
+ async def get (self , key : KeyT ) -> ValueT | None :
30
+ async with self ._lock :
31
+ if key in self ._cache :
32
+ self ._lru .remove (key )
33
+ self ._lru .append (key )
34
+ return self ._cache [key ]
35
+ else :
36
+ return None
37
+
38
+ async def set (self , key : KeyT , value : ValueT ) -> None :
39
+ async with self ._lock :
40
+ if key in self ._cache :
41
+ self ._lru .remove (key )
42
+ elif len (self ._cache ) >= self .max_size :
43
+ del self ._cache [self ._lru .pop (0 )]
44
+ self ._cache [key ] = value
45
+ self ._lru .append (key )
46
+
47
+ async def delete (self , key : KeyT ) -> None :
48
+ async with self ._lock :
49
+ if key in self ._cache :
50
+ del self ._cache [key ]
51
+ self ._lru .remove (key )
52
+
53
+ async def prune (self , keep : int | None = 128 ) -> None :
54
+ async with self ._lock :
55
+ if keep is None :
56
+ self ._cache = {}
57
+ self ._lru = []
58
+ else :
59
+ self ._lru = self ._lru [- keep :]
60
+ self ._cache = {key : self ._cache [key ] for key in self ._lru }
61
+
62
class LRUCacheConsumers(TypedDict):
    # Registry of the per-endpoint count caches (one LRUCache per consumer);
    # each cache maps a query key to a previously computed row count.
    batch_job_count: LRUCache[tuple[int, datetime.datetime | None], int]  # Key is (batch_id, after)
    job_count: LRUCache[datetime.datetime | None, int]  # Key is after
    batch_count: LRUCache[datetime.datetime | None, int]  # Key is after
    url_count: LRUCache[datetime.datetime | None, int]  # Key is after
    repeat_url_count: LRUCache[datetime.datetime | None, int]  # Key is after
68
+
69
+
18
70
# Application-wide singletons. engine, async_session and cache are populated by
# lifespan() at startup; until then they hold None despite the non-optional
# annotations.
engine: sqlalchemy.ext.asyncio.AsyncEngine = None  # Async DB engine (created from DATABASE_URL)
async_session: sqlalchemy.ext.asyncio.async_sessionmaker[sqlalchemy.ext.asyncio.AsyncSession] = None  # Session factory bound to `engine`
client_session: ClientSession = None  # Shared aiohttp session — presumably created during startup; TODO confirm where
cache: LRUCacheConsumers = None  # Count caches used by the paginated endpoints (see LRUCacheConsumers)
21
74
22
75
class Base(sqlalchemy.orm.MappedAsDataclass, sqlalchemy.ext.asyncio.AsyncAttrs, sqlalchemy.orm.DeclarativeBase):
    """Declarative base for all ORM models: dataclass-style init/repr plus async attribute access."""
    pass
@@ -26,7 +79,7 @@ class Batch(Base):
26
79
__tablename__ = "batches"
27
80
28
81
id : Mapped [int ] = mapped_column (primary_key = True , autoincrement = True , init = False )
29
- created_at : Mapped [datetime .datetime ] = mapped_column (sqlalchemy .DateTime (timezone = True ), server_default = sqlalchemy .sql .func .now (), nullable = False , init = False )
82
+ created_at : Mapped [datetime .datetime ] = mapped_column (sqlalchemy .DateTime (timezone = True ), server_default = sqlalchemy .sql .func .now (), nullable = False , init = False , index = True )
30
83
31
84
jobs : Mapped [list ["Job" ]] = sqlalchemy .orm .relationship ("Job" , secondary = "batch_jobs" , back_populates = "batches" , init = False , repr = False )
32
85
@@ -35,17 +88,17 @@ class URL(Base):
35
88
36
89
id : Mapped [int ] = mapped_column (primary_key = True , autoincrement = True , init = False )
37
90
url : Mapped [str ] = mapped_column (sqlalchemy .String (length = 10000 ), unique = True , index = True )
38
- first_seen : Mapped [datetime .datetime ] = mapped_column (sqlalchemy .DateTime (timezone = True ), server_default = sqlalchemy .sql .func .now (), nullable = False , init = False )
91
+ first_seen : Mapped [datetime .datetime ] = mapped_column (sqlalchemy .DateTime (timezone = True ), server_default = sqlalchemy .sql .func .now (), nullable = False , init = False , index = True )
39
92
jobs : Mapped [list ["Job" ]] = sqlalchemy .orm .relationship ("Job" , back_populates = "url" , init = False , repr = False )
40
- last_seen : Mapped [datetime .datetime | None ] = mapped_column (sqlalchemy .DateTime (timezone = True ), default = None , nullable = True )
93
+ last_seen : Mapped [datetime .datetime | None ] = mapped_column (sqlalchemy .DateTime (timezone = True ), default = None , nullable = True , index = True )
41
94
42
95
class RepeatURL (Base ):
43
96
__tablename__ = "repeat_urls"
44
97
45
98
id : Mapped [int ] = mapped_column (primary_key = True , autoincrement = True , init = False )
46
99
url_id : Mapped [int ] = mapped_column (sqlalchemy .ForeignKey (URL .id ), nullable = False , unique = True , init = False )
47
100
url : Mapped [URL ] = sqlalchemy .orm .relationship (URL , lazy = "joined" , innerjoin = True , foreign_keys = [url_id ])
48
- created_at : Mapped [datetime .datetime ] = mapped_column (sqlalchemy .DateTime (timezone = True ), server_default = sqlalchemy .sql .func .now (), nullable = False , init = False )
101
+ created_at : Mapped [datetime .datetime ] = mapped_column (sqlalchemy .DateTime (timezone = True ), server_default = sqlalchemy .sql .func .now (), nullable = False , init = False , index = True )
49
102
batch_id : Mapped [int ] = mapped_column (sqlalchemy .ForeignKey (Batch .id ), nullable = False , unique = True , init = False )
50
103
batch : Mapped [Batch ] = sqlalchemy .orm .relationship (Batch , lazy = "joined" , innerjoin = True , foreign_keys = [batch_id ])
51
104
interval : Mapped [int ] = mapped_column (default = 3600 * 4 )
@@ -66,17 +119,17 @@ class Job(Base):
66
119
batches : Mapped [list [Batch ]] = sqlalchemy .orm .relationship (Batch , secondary = "batch_jobs" , back_populates = "jobs" )
67
120
url_id : Mapped [int ] = mapped_column (sqlalchemy .ForeignKey (URL .id ), nullable = False , init = False , repr = False )
68
121
url : Mapped [URL ] = sqlalchemy .orm .relationship (URL , lazy = "joined" , innerjoin = True , foreign_keys = [url_id ], back_populates = "jobs" )
69
- created_at : Mapped [datetime .datetime ] = mapped_column (sqlalchemy .DateTime (timezone = True ), server_default = sqlalchemy .sql .func .now (), nullable = False , init = False )
70
- completed : Mapped [datetime .datetime | None ] = mapped_column (sqlalchemy .DateTime (timezone = True ), default = None , nullable = True )
71
- delayed_until : Mapped [datetime .datetime | None ] = mapped_column (sqlalchemy .DateTime (timezone = True ), default = None , nullable = True ) # If a job needs to be delayed, this is the time it should be run at
122
+ created_at : Mapped [datetime .datetime ] = mapped_column (sqlalchemy .DateTime (timezone = True ), server_default = sqlalchemy .sql .func .now (), nullable = False , init = False , index = True )
123
+ completed : Mapped [datetime .datetime | None ] = mapped_column (sqlalchemy .DateTime (timezone = True ), default = None , nullable = True , index = True )
124
+ delayed_until : Mapped [datetime .datetime | None ] = mapped_column (sqlalchemy .DateTime (timezone = True ), default = None , nullable = True , index = True ) # If a job needs to be delayed, this is the time it should be run at
72
125
priority : Mapped [int ] = mapped_column (default = 0 )
73
126
retry : Mapped [int ] = mapped_column (sqlalchemy .SmallInteger , default = 0 ) # Number of times this job has been retried
74
- failed : Mapped [datetime .datetime | None ] = mapped_column (sqlalchemy .DateTime (timezone = True ), default = None , nullable = True ) # If a job has failed, this is the time it failed at
127
+ failed : Mapped [datetime .datetime | None ] = mapped_column (sqlalchemy .DateTime (timezone = True ), default = None , nullable = True , index = True ) # If a job has failed, this is the time it failed at
75
128
76
129
async def get_current_job (curtime = None , * , session : sqlalchemy .ext .asyncio .AsyncSession | None = None , get_batches : bool = False ) -> Job | None :
77
130
if not curtime :
78
131
curtime = datetime .datetime .now (tz = datetime .timezone .utc )
79
- stmt = select (Job ).where (((Job .delayed_until <= curtime ) | (Job .delayed_until == None )) & (Job .completed == None ) & (Job .failed == None )).order_by (Job .priority .desc (), Job .created_at ).limit (1 )
132
+ stmt = select (Job ).where (((Job .delayed_until <= curtime ) | (Job .delayed_until == None )) & (Job .completed == None ) & (Job .failed == None )).order_by (Job .priority .desc (), Job .retry . desc (), Job . id ).limit (1 )
80
133
if get_batches :
81
134
stmt = stmt .options (sqlalchemy .orm .joinedload (Job .batches ))
82
135
if session is None :
@@ -142,7 +195,7 @@ async def repeat_url_worker():
142
195
while True :
143
196
curtime = datetime .datetime .now (tz = datetime .timezone .utc )
144
197
async with async_session () as session , session .begin ():
145
- stmt = select (RepeatURL ).where (RepeatURL .active_since <= curtime ).order_by (RepeatURL .created_at )
198
+ stmt = select (RepeatURL ).where (RepeatURL .active_since <= curtime ).order_by (RepeatURL .id )
146
199
result = await session .scalars (stmt )
147
200
jobs = result .all ()
148
201
stmt2 = select (URL .url ).join (Job ).where (URL .url .in_ ([job .url .url for job in jobs ]) & (Job .completed == None ) & (Job .failed == None ))
@@ -165,14 +218,15 @@ async def repeat_url_worker():
165
218
166
219
@asynccontextmanager
167
220
async def lifespan (_ : FastAPI ):
168
- global engine , async_session
221
+ global engine , async_session , cache
169
222
if engine is None :
170
223
engine = sqlalchemy .ext .asyncio .create_async_engine (environ .get ("DATABASE_URL" , "sqlite:///db.sqlite" ))
171
224
async_session = sqlalchemy .ext .asyncio .async_sessionmaker (engine , expire_on_commit = False )
172
225
async with engine .begin () as conn :
173
226
await conn .run_sync (Base .metadata .create_all )
174
227
workers .append (asyncio .create_task (exception_logger (url_worker (), name = "url_worker" )))
175
228
workers .append (asyncio .create_task (exception_logger (repeat_url_worker (), name = "repeat_url_worker" )))
229
+ cache = {key : LRUCache () for key in LRUCacheConsumers .__required_keys__ }
176
230
try :
177
231
yield
178
232
finally :
@@ -351,7 +405,13 @@ class JobReturn(BaseModel):
351
405
batches : list [int ] = []
352
406
353
407
@classmethod
354
- def from_job (cls , job : Job ):
408
+ def from_job (cls , job : Job , batch_ids : list [int ] | None = None ):
409
+ """Make a JobReturn from a Job
410
+
411
+ :param job: The job to make a JobReturn from
412
+ :param batch_ids: Provide a list of batch IDs when a joinedload was not used
413
+ :return: A JobReturn
414
+ """
355
415
return cls (
356
416
id = job .id ,
357
417
url = job .url .url ,
@@ -361,7 +421,7 @@ def from_job(cls, job: Job):
361
421
priority = job .priority ,
362
422
retry = job .retry ,
363
423
failed = job .failed ,
364
- batches = [batch .id for batch in job .batches ]
424
+ batches = batch_ids if batch_ids is not None else [batch .id for batch in job .batches ]
365
425
)
366
426
367
427
class CurrentJobReturn (BaseModel ):
@@ -377,31 +437,127 @@ async def current_job() -> CurrentJobReturn:
377
437
}
378
438
379
439
@app.get("/job/{job_id}")
async def get_job(job_id: Annotated[int, Path(title="Job ID", description="The ID of the job you want info on", ge=1)]) -> JobReturn:
    """Fetch a single job by ID, including the IDs of the batches it belongs to.

    :param job_id: Primary key of the job
    :raises HTTPException: 404 if no job with that ID exists
    """
    async with async_session() as session, session.begin():
        # Use selectinload rather than joinedload here: in SQLAlchemy 2.0 a joined
        # eager load against a collection requires Result.unique() before fetching,
        # which session.scalar() does not apply. selectinload loads the batches via
        # a second query and needs no de-duplication; job_id is the primary key, so
        # an explicit LIMIT is unnecessary.
        stmt = select(Job).where(Job.id == job_id).options(sqlalchemy.orm.selectinload(Job.batches))
        job = await session.scalar(stmt)
        if job is None:
            raise HTTPException(status_code=404, detail="Job not found")
        return JobReturn.from_job(job)
387
447
388
448
class BatchReturn(BaseModel):
    """Response model summarizing one batch (returned by GET /batch/{batch_id} and /batches)."""
    id: int
    """The ID of the batch"""
    created_at: datetime.datetime
    """The time the batch was created"""
    repeat_url: int | None = None
    """The ID of the repeat URL that the batch represents, if any"""
    jobs: int
    """The number of jobs in the batch"""
392
457
393
458
@app.get("/batch/{batch_id}")
async def get_batch(batch_id: Annotated[int, Path(title="Batch ID", description="The ID of the batch you want info on", ge=1)]) -> BatchReturn:
    """Fetch summary info for a batch: creation time, linked repeat URL, and job count.

    :param batch_id: Primary key of the batch
    :raises HTTPException: 404 if no batch with that ID exists
    """
    async with async_session() as session, session.begin():
        stmt = select(Batch).where(Batch.id == batch_id).limit(1)
        batch = await session.scalar(stmt)
        if batch is None:
            raise HTTPException(status_code=404, detail="Batch not found")
        # A batch may be the template batch of a RepeatURL; surface that ID if so.
        stmt = select(RepeatURL.id).join(RepeatURL.batch).where(Batch.id == batch_id).limit(1)
        repeat_url = await session.scalar(stmt)
        # Count jobs with an aggregate rather than loading the (possibly large) collection.
        job_count = await session.scalar(select(sqlalchemy.func.count(Job.id)).join(Batch.jobs).where(Batch.id == batch_id))
        return {
            "id": batch.id,
            "created_at": batch.created_at,
            "repeat_url": repeat_url,
            "jobs": job_count
        }
474
+
475
class PaginationInfo(BaseModel):
    # Pagination metadata returned alongside every paginated listing.
    current_page: int  # 1-based page number that was returned
    total_pages: int  # total number of pages available
    items: int  # total number of matching items across all pages

ModelT = TypeVar("ModelT", bound=BaseModel)

class PaginationOutput(BaseModel, Generic[ModelT]):
    # Generic envelope for paginated endpoint responses.
    data: list[ModelT]  # the items on this page (at most 100 — see endpoint LIMITs)
    pagination: PaginationInfo  # metadata about the full result set

# Reusable validated query parameter for selecting a result page.
Page = Annotated[int, Query(title="Page", description="The page of results you want", ge=1, le=100)]
487
+
488
@app.get("/jobs")
async def get_jobs(page: Page = 1, after: datetime.datetime | None = None, desc: bool = False) -> PaginationOutput[JobReturn]:
    """Paginated listing of jobs (100 per page).

    :param page: 1-based page number
    :param after: Only include jobs created strictly after this time
    :param desc: Sort by descending job ID instead of ascending
    """
    job_count_cache = cache["job_count"]
    cache_key = after
    async with async_session() as session, session.begin():
        # The total count is comparatively expensive, so it is cached per `after`.
        if (job_count := await job_count_cache.get(cache_key)) is None:
            stmt = select(sqlalchemy.func.count(Job.id))
            if after is not None:
                stmt = stmt.where(Job.created_at > after)
            job_count = await session.scalar(stmt)
            await job_count_cache.set(cache_key, job_count)
        stmt = select(Job).order_by(Job.id.desc() if desc else Job.id.asc()).offset((page - 1) * 100).limit(100).options(sqlalchemy.orm.joinedload(Job.batches))
        if after is not None:
            stmt = stmt.where(Job.created_at > after)
        result = await session.scalars(stmt)
        return PaginationOutput(
            data=[JobReturn.from_job(job) for job in result.unique().all()],
            pagination=PaginationInfo(
                current_page=page,
                # Ceiling division, minimum one page. The old `count // 100 + 1`
                # reported a phantom empty page whenever count was a multiple of 100.
                total_pages=max(1, -(-job_count // 100)),
                items=job_count
            )
        )
511
+
512
@app.get("/batch/{batch_id}/jobs")
async def get_batch_jobs(batch_id: Annotated[int, Path(title="Batch ID", description="The ID of the batch you want info on", ge=1)], page: Page = 1, after: datetime.datetime | None = None, desc: bool = False) -> PaginationOutput[JobReturn]:
    """Paginated listing of the jobs belonging to one batch (100 per page).

    :param batch_id: Primary key of the batch whose jobs are listed
    :param page: 1-based page number
    :param after: Only include jobs created strictly after this time
    :param desc: Sort by descending job ID instead of ascending
    """
    job_count_cache = cache["batch_job_count"]
    cache_key = (batch_id, after)
    async with async_session() as session, session.begin():
        # The total count is comparatively expensive, so it is cached per (batch_id, after).
        if (job_count := await job_count_cache.get(cache_key)) is None:
            stmt = select(sqlalchemy.func.count(Job.id)).join(Batch.jobs).where(Batch.id == batch_id)
            if after is not None:
                stmt = stmt.where(Job.created_at > after)
            job_count = await session.scalar(stmt)
            await job_count_cache.set(cache_key, job_count)
        stmt = select(Job).join(Batch.jobs).where(Batch.id == batch_id).order_by(Job.id.desc() if desc else Job.id.asc()).offset((page - 1) * 100).limit(100).options(sqlalchemy.orm.joinedload(Job.batches))
        if after is not None:
            stmt = stmt.where(Job.created_at > after)
        result = await session.scalars(stmt)
        return PaginationOutput(
            data=[JobReturn.from_job(job) for job in result.unique().all()],
            pagination=PaginationInfo(
                current_page=page,
                # Ceiling division, minimum one page. The old `count // 100 + 1`
                # reported a phantom empty page whenever count was a multiple of 100.
                total_pages=max(1, -(-job_count // 100)),
                items=job_count
            )
        )
535
+
536
@app.get("/batches")
async def get_batches(page: Page = 1, after: datetime.datetime | None = None, desc: bool = False) -> PaginationOutput[BatchReturn]:
    """Paginated listing of batches (100 per page) with their job counts.

    :param page: 1-based page number
    :param after: Only include batches created strictly after this time
    :param desc: Sort by descending batch ID instead of ascending
    """
    batch_count_cache = cache["batch_count"]
    cache_key = after
    async with async_session() as session, session.begin():
        # The total count is comparatively expensive, so it is cached per `after`.
        if (batch_count := await batch_count_cache.get(cache_key)) is None:
            stmt = select(sqlalchemy.func.count(Batch.id))
            if after is not None:
                stmt = stmt.where(Batch.created_at > after)
            batch_count = await session.scalar(stmt)
            await batch_count_cache.set(cache_key, batch_count)
        stmt = select(Batch).order_by(Batch.id.desc() if desc else Batch.id.asc()).offset((page - 1) * 100).limit(100).options(sqlalchemy.orm.joinedload(Batch.jobs))
        if after is not None:
            stmt = stmt.where(Batch.created_at > after)
        result = await session.scalars(stmt)
        return PaginationOutput(
            # NOTE(review): len(batch.jobs) forces eager-loading every job of every
            # listed batch just to count them; a correlated count subquery would be
            # cheaper — left as-is to keep behavior identical.
            data=[BatchReturn(id=batch.id, created_at=batch.created_at, jobs=len(batch.jobs)) for batch in result.unique().all()],
            pagination=PaginationInfo(
                current_page=page,
                # Ceiling division, minimum one page. The old `count // 100 + 1`
                # reported a phantom empty page whenever count was a multiple of 100.
                total_pages=max(1, -(-batch_count // 100)),
                items=batch_count
            )
        )
559
+
560
+
405
561
class URLInfoBody (BaseModel ):
406
562
url : str
407
563
@@ -413,7 +569,7 @@ class URLReturn(BaseModel):
413
569
@app .post ("/url" )
414
570
async def get_url_info (body : URLInfoBody ) -> URLReturn :
415
571
async with async_session () as session , session .begin ():
416
- stmt = select (URL ).where (URL .url == body .url ).limit (1 )
572
+ stmt = select (URL ).where (URL .url == body .url ).limit (1 ). options ( sqlalchemy . orm . joinedload ( URL . jobs )). options ( sqlalchemy . orm . joinedload ( URL . jobs , Job . batches ))
417
573
url = await session .scalar (stmt )
418
574
if url is None :
419
575
raise HTTPException (status_code = 404 , detail = "URL not found" )
0 commit comments