Skip to content

Commit 7f614bc

Browse files
committed
Add a lot more endpoints
1 parent 9d83076 commit 7f614bc

File tree

1 file changed

+177
-21
lines changed

1 file changed

+177
-21
lines changed

src/main.py

Lines changed: 177 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import datetime
44
from os import environ
55
import re
6-
from typing import Awaitable
6+
from typing import Annotated, Awaitable, Generic, TypeVar, TypedDict
77
from traceback import print_exc
88

99
from pydantic import BaseModel
@@ -12,12 +12,65 @@
1212
import sqlalchemy.ext.asyncio
1313
from sqlalchemy import select, update
1414
from sqlalchemy.orm import Mapped, mapped_column
15-
from fastapi import FastAPI, HTTPException
15+
from fastapi import FastAPI, HTTPException, Path, Query
1616
from aiohttp import ClientSession
1717

18+
KeyT = TypeVar("KeyT")
19+
ValueT = TypeVar("ValueT")
20+
21+
# LRUCache that has sections is accessed asynchonously
22+
class LRUCache(Generic[KeyT, ValueT]):
23+
def __init__(self, max_size: int = 1024):
24+
self.max_size = max_size
25+
self._cache = {}
26+
self._lru = []
27+
self._lock = asyncio.Lock()
28+
29+
async def get(self, key: KeyT) -> ValueT | None:
30+
async with self._lock:
31+
if key in self._cache:
32+
self._lru.remove(key)
33+
self._lru.append(key)
34+
return self._cache[key]
35+
else:
36+
return None
37+
38+
async def set(self, key: KeyT, value: ValueT) -> None:
39+
async with self._lock:
40+
if key in self._cache:
41+
self._lru.remove(key)
42+
elif len(self._cache) >= self.max_size:
43+
del self._cache[self._lru.pop(0)]
44+
self._cache[key] = value
45+
self._lru.append(key)
46+
47+
async def delete(self, key: KeyT) -> None:
48+
async with self._lock:
49+
if key in self._cache:
50+
del self._cache[key]
51+
self._lru.remove(key)
52+
53+
async def prune(self, keep: int | None = 128) -> None:
54+
async with self._lock:
55+
if keep is None:
56+
self._cache = {}
57+
self._lru = []
58+
else:
59+
self._lru = self._lru[-keep:]
60+
self._cache = {key: self._cache[key] for key in self._lru}
61+
62+
class LRUCacheConsumers(TypedDict):
    """Registry of the named count caches used by the paginated endpoints.

    Each value is an LRUCache mapping a query key to a cached row count;
    ``after`` in the key types is the optional created-at cutoff filter.
    """
    batch_job_count: LRUCache[tuple[int, datetime.datetime | None], int] # Key is (batch_id, after)
    job_count: LRUCache[datetime.datetime | None, int] # Key is after
    batch_count: LRUCache[datetime.datetime | None, int] # Key is after
    url_count: LRUCache[datetime.datetime | None, int] # Key is after
    repeat_url_count: LRUCache[datetime.datetime | None, int] # Key is after
68+
69+
1870
engine: sqlalchemy.ext.asyncio.AsyncEngine = None
1971
async_session: sqlalchemy.ext.asyncio.async_sessionmaker[sqlalchemy.ext.asyncio.AsyncSession] = None
2072
client_session: ClientSession = None
73+
cache: LRUCacheConsumers = None
2174

2275
class Base(sqlalchemy.orm.MappedAsDataclass, sqlalchemy.ext.asyncio.AsyncAttrs, sqlalchemy.orm.DeclarativeBase):
2376
pass
@@ -26,7 +79,7 @@ class Batch(Base):
2679
__tablename__ = "batches"
2780

2881
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True, init=False)
29-
created_at: Mapped[datetime.datetime] = mapped_column(sqlalchemy.DateTime(timezone=True), server_default=sqlalchemy.sql.func.now(), nullable=False, init=False)
82+
created_at: Mapped[datetime.datetime] = mapped_column(sqlalchemy.DateTime(timezone=True), server_default=sqlalchemy.sql.func.now(), nullable=False, init=False, index=True)
3083

3184
jobs: Mapped[list["Job"]] = sqlalchemy.orm.relationship("Job", secondary="batch_jobs", back_populates="batches", init=False, repr=False)
3285

@@ -35,17 +88,17 @@ class URL(Base):
3588

3689
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True, init=False)
3790
url: Mapped[str] = mapped_column(sqlalchemy.String(length=10000), unique=True, index=True)
38-
first_seen: Mapped[datetime.datetime] = mapped_column(sqlalchemy.DateTime(timezone=True), server_default=sqlalchemy.sql.func.now(), nullable=False, init=False)
91+
first_seen: Mapped[datetime.datetime] = mapped_column(sqlalchemy.DateTime(timezone=True), server_default=sqlalchemy.sql.func.now(), nullable=False, init=False, index=True)
3992
jobs: Mapped[list["Job"]] = sqlalchemy.orm.relationship("Job", back_populates="url", init=False, repr=False)
40-
last_seen: Mapped[datetime.datetime | None] = mapped_column(sqlalchemy.DateTime(timezone=True), default=None, nullable=True)
93+
last_seen: Mapped[datetime.datetime | None] = mapped_column(sqlalchemy.DateTime(timezone=True), default=None, nullable=True, index=True)
4194

4295
class RepeatURL(Base):
4396
__tablename__ = "repeat_urls"
4497

4598
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True, init=False)
4699
url_id: Mapped[int] = mapped_column(sqlalchemy.ForeignKey(URL.id), nullable=False, unique=True, init=False)
47100
url: Mapped[URL] = sqlalchemy.orm.relationship(URL, lazy="joined", innerjoin=True, foreign_keys=[url_id])
48-
created_at: Mapped[datetime.datetime] = mapped_column(sqlalchemy.DateTime(timezone=True), server_default=sqlalchemy.sql.func.now(), nullable=False, init=False)
101+
created_at: Mapped[datetime.datetime] = mapped_column(sqlalchemy.DateTime(timezone=True), server_default=sqlalchemy.sql.func.now(), nullable=False, init=False, index=True)
49102
batch_id: Mapped[int] = mapped_column(sqlalchemy.ForeignKey(Batch.id), nullable=False, unique=True, init=False)
50103
batch: Mapped[Batch] = sqlalchemy.orm.relationship(Batch, lazy="joined", innerjoin=True, foreign_keys=[batch_id])
51104
interval: Mapped[int] = mapped_column(default=3600 * 4)
@@ -66,17 +119,17 @@ class Job(Base):
66119
batches: Mapped[list[Batch]] = sqlalchemy.orm.relationship(Batch, secondary="batch_jobs", back_populates="jobs")
67120
url_id: Mapped[int] = mapped_column(sqlalchemy.ForeignKey(URL.id), nullable=False, init=False, repr=False)
68121
url: Mapped[URL] = sqlalchemy.orm.relationship(URL, lazy="joined", innerjoin=True, foreign_keys=[url_id], back_populates="jobs")
69-
created_at: Mapped[datetime.datetime] = mapped_column(sqlalchemy.DateTime(timezone=True), server_default=sqlalchemy.sql.func.now(), nullable=False, init=False)
70-
completed: Mapped[datetime.datetime | None] = mapped_column(sqlalchemy.DateTime(timezone=True), default=None, nullable=True)
71-
delayed_until: Mapped[datetime.datetime | None] = mapped_column(sqlalchemy.DateTime(timezone=True), default=None, nullable=True) # If a job needs to be delayed, this is the time it should be run at
122+
created_at: Mapped[datetime.datetime] = mapped_column(sqlalchemy.DateTime(timezone=True), server_default=sqlalchemy.sql.func.now(), nullable=False, init=False, index=True)
123+
completed: Mapped[datetime.datetime | None] = mapped_column(sqlalchemy.DateTime(timezone=True), default=None, nullable=True, index=True)
124+
delayed_until: Mapped[datetime.datetime | None] = mapped_column(sqlalchemy.DateTime(timezone=True), default=None, nullable=True, index=True) # If a job needs to be delayed, this is the time it should be run at
72125
priority: Mapped[int] = mapped_column(default=0)
73126
retry: Mapped[int] = mapped_column(sqlalchemy.SmallInteger, default=0) # Number of times this job has been retried
74-
failed: Mapped[datetime.datetime | None] = mapped_column(sqlalchemy.DateTime(timezone=True), default=None, nullable=True) # If a job has failed, this is the time it failed at
127+
failed: Mapped[datetime.datetime | None] = mapped_column(sqlalchemy.DateTime(timezone=True), default=None, nullable=True, index=True) # If a job has failed, this is the time it failed at
75128

76129
async def get_current_job(curtime=None, *, session: sqlalchemy.ext.asyncio.AsyncSession | None = None, get_batches: bool = False) -> Job | None:
77130
if not curtime:
78131
curtime = datetime.datetime.now(tz=datetime.timezone.utc)
79-
stmt = select(Job).where(((Job.delayed_until <= curtime) | (Job.delayed_until == None)) & (Job.completed == None) & (Job.failed == None)).order_by(Job.priority.desc(), Job.created_at).limit(1)
132+
stmt = select(Job).where(((Job.delayed_until <= curtime) | (Job.delayed_until == None)) & (Job.completed == None) & (Job.failed == None)).order_by(Job.priority.desc(), Job.retry.desc(), Job.id).limit(1)
80133
if get_batches:
81134
stmt = stmt.options(sqlalchemy.orm.joinedload(Job.batches))
82135
if session is None:
@@ -142,7 +195,7 @@ async def repeat_url_worker():
142195
while True:
143196
curtime = datetime.datetime.now(tz=datetime.timezone.utc)
144197
async with async_session() as session, session.begin():
145-
stmt = select(RepeatURL).where(RepeatURL.active_since <= curtime).order_by(RepeatURL.created_at)
198+
stmt = select(RepeatURL).where(RepeatURL.active_since <= curtime).order_by(RepeatURL.id)
146199
result = await session.scalars(stmt)
147200
jobs = result.all()
148201
stmt2 = select(URL.url).join(Job).where(URL.url.in_([job.url.url for job in jobs]) & (Job.completed == None) & (Job.failed == None))
@@ -165,14 +218,15 @@ async def repeat_url_worker():
165218

166219
@asynccontextmanager
167220
async def lifespan(_: FastAPI):
168-
global engine, async_session
221+
global engine, async_session, cache
169222
if engine is None:
170223
engine = sqlalchemy.ext.asyncio.create_async_engine(environ.get("DATABASE_URL", "sqlite:///db.sqlite"))
171224
async_session = sqlalchemy.ext.asyncio.async_sessionmaker(engine, expire_on_commit=False)
172225
async with engine.begin() as conn:
173226
await conn.run_sync(Base.metadata.create_all)
174227
workers.append(asyncio.create_task(exception_logger(url_worker(), name="url_worker")))
175228
workers.append(asyncio.create_task(exception_logger(repeat_url_worker(), name="repeat_url_worker")))
229+
cache = {key: LRUCache() for key in LRUCacheConsumers.__required_keys__}
176230
try:
177231
yield
178232
finally:
@@ -351,7 +405,13 @@ class JobReturn(BaseModel):
351405
batches: list[int] = []
352406

353407
@classmethod
354-
def from_job(cls, job: Job):
408+
def from_job(cls, job: Job, batch_ids: list[int] | None = None):
409+
"""Make a JobReturn from a Job
410+
411+
:param job: The job to make a JobReturn from
412+
:param batch_ids: Provide a list of batch IDs when a joinedload was not used
413+
:return: A JobReturn
414+
"""
355415
return cls(
356416
id=job.id,
357417
url=job.url.url,
@@ -361,7 +421,7 @@ def from_job(cls, job: Job):
361421
priority=job.priority,
362422
retry=job.retry,
363423
failed=job.failed,
364-
batches=[batch.id for batch in job.batches]
424+
batches=batch_ids if batch_ids is not None else [batch.id for batch in job.batches]
365425
)
366426

367427
class CurrentJobReturn(BaseModel):
@@ -377,31 +437,127 @@ async def current_job() -> CurrentJobReturn:
377437
}
378438

379439
@app.get("/job/{job_id}")
async def get_job(job_id: Annotated[int, Path(title="Job ID", description="The ID of the job you want info on", ge=1)]) -> JobReturn:
    """Return full info for a single job, including the IDs of its batches.

    :param job_id: primary key of the job
    :raises HTTPException: 404 if no job with that ID exists
    """
    async with async_session() as session, session.begin():
        # selectinload instead of joinedload: eagerly loading the `batches`
        # collection via a join multiplies rows, which requires unique()-style
        # de-duplication and forces SQLAlchemy to wrap the LIMIT query in a
        # subquery; selectinload issues a second simple IN query instead.
        stmt = select(Job).where(Job.id == job_id).limit(1).options(sqlalchemy.orm.selectinload(Job.batches))
        job = await session.scalar(stmt)
        if job is None:
            raise HTTPException(status_code=404, detail="Job not found")
        return JobReturn.from_job(job)
387447

388448
class BatchReturn(BaseModel):
    """Response model summarizing one batch of jobs."""
    id: int
    """The ID of the batch"""
    created_at: datetime.datetime
    """The time the batch was created"""
    repeat_url: int | None = None
    """The ID of the repeat URL that the batch represents, if any"""
    jobs: int
    """The number of jobs in the batch"""
392457

393458
@app.get("/batch/{batch_id}")
async def get_batch(batch_id: Annotated[int, Path(title="Batch ID", description="The ID of the batch you want info on", ge=1)]) -> BatchReturn:
    """Summarize one batch: creation time, owning repeat URL (if any), and job count.

    :raises HTTPException: 404 if no batch with that ID exists
    """
    async with async_session() as session, session.begin():
        batch = await session.scalar(select(Batch).where(Batch.id == batch_id).limit(1))
        if batch is None:
            raise HTTPException(status_code=404, detail="Batch not found")
        # ID of the RepeatURL whose generated batch this is, or None.
        repeat_url_id = await session.scalar(
            select(RepeatURL.id).join(RepeatURL.batch).where(Batch.id == batch_id).limit(1)
        )
        # Count through the association rather than loading the jobs themselves.
        total_jobs = await session.scalar(
            select(sqlalchemy.func.count(Job.id)).join(Batch.jobs).where(Batch.id == batch_id)
        )
        return {
            "id": batch.id,
            "created_at": batch.created_at,
            "repeat_url": repeat_url_id,
            "jobs": total_jobs
        }
474+
475+
class PaginationInfo(BaseModel):
    # 1-based number of the page contained in this response
    current_page: int
    # total number of pages available for the query (pages hold 100 items)
    total_pages: int
    # total number of matching items across all pages
    items: int

ModelT = TypeVar("ModelT", bound=BaseModel)

class PaginationOutput(BaseModel, Generic[ModelT]):
    """Generic envelope for paginated list endpoints: one page of items plus pagination metadata."""
    data: list[ModelT]
    pagination: PaginationInfo

# Reusable query-parameter type for the `page` argument of the list endpoints.
Page = Annotated[int, Query(title="Page", description="The page of results you want", ge=1, le=100)]
487+
488+
@app.get("/jobs")
async def get_jobs(page: Page = 1, after: datetime.datetime | None = None, desc: bool = False) -> PaginationOutput[JobReturn]:
    """Return a page of (up to) 100 jobs.

    :param page: 1-based page number
    :param after: if given, only jobs created after this time are counted/returned
    :param desc: return highest job IDs (newest) first when true
    """
    job_count_cache = cache["job_count"]
    cache_key = after
    async with async_session() as session, session.begin():
        # The total count is the expensive part, so it is cached per `after`.
        if (job_count := await job_count_cache.get(cache_key)) is None:
            stmt = select(sqlalchemy.func.count(Job.id))
            if after:
                stmt = stmt.where(Job.created_at > after)
            job_count = await session.scalar(stmt)
            await job_count_cache.set(cache_key, job_count)
        # selectinload: a joinedload against the `batches` collection would
        # multiply rows (requiring unique()) and force SQLAlchemy to wrap the
        # LIMIT/OFFSET query in a subquery.
        stmt = select(Job).order_by(Job.id.desc() if desc else Job.id.asc()).offset((page - 1) * 100).limit(100).options(sqlalchemy.orm.selectinload(Job.batches))
        if after:
            stmt = stmt.where(Job.created_at > after)
        result = await session.scalars(stmt)
        return PaginationOutput(
            data=[JobReturn.from_job(job) for job in result.all()],
            pagination=PaginationInfo(
                current_page=page,
                # Ceiling division: an exact multiple of 100 no longer reports
                # a trailing empty page (the old `// 100 + 1` did); an empty
                # result still reports one page.
                total_pages=max(1, (job_count + 99) // 100),
                items=job_count
            )
        )
511+
512+
@app.get("/batch/{batch_id}/jobs")
async def get_batch_jobs(batch_id: Annotated[int, Path(title="Batch ID", description="The ID of the batch you want info on", ge=1)], page: Page = 1, after: datetime.datetime | None = None, desc: bool = False) -> PaginationOutput[JobReturn]:
    """Return a page of (up to) 100 jobs belonging to one batch.

    :param batch_id: primary key of the batch
    :param page: 1-based page number
    :param after: if given, only jobs created after this time are counted/returned
    :param desc: return highest job IDs (newest) first when true

    NOTE(review): an unknown batch_id yields an empty page rather than 404,
    matching the original behavior.
    """
    job_count_cache = cache["batch_job_count"]
    cache_key = (batch_id, after)
    async with async_session() as session, session.begin():
        # The per-batch total count is cached keyed on (batch_id, after).
        if (job_count := await job_count_cache.get(cache_key)) is None:
            stmt = select(sqlalchemy.func.count(Job.id)).join(Batch.jobs).where(Batch.id == batch_id)
            if after:
                stmt = stmt.where(Job.created_at > after)
            job_count = await session.scalar(stmt)
            await job_count_cache.set(cache_key, job_count)
        # selectinload avoids the row multiplication (and unique() requirement)
        # of joinedload on a collection combined with LIMIT/OFFSET.
        stmt = select(Job).join(Batch.jobs).where(Batch.id == batch_id).order_by(Job.id.desc() if desc else Job.id.asc()).offset((page - 1) * 100).limit(100).options(sqlalchemy.orm.selectinload(Job.batches))
        if after:
            stmt = stmt.where(Job.created_at > after)
        result = await session.scalars(stmt)
        return PaginationOutput(
            data=[JobReturn.from_job(job) for job in result.all()],
            pagination=PaginationInfo(
                current_page=page,
                # Ceiling division fixes the extra empty page the old
                # `// 100 + 1` reported on exact multiples of 100.
                total_pages=max(1, (job_count + 99) // 100),
                items=job_count
            )
        )
535+
536+
@app.get("/batches")
async def get_batches(page: Page = 1, after: datetime.datetime | None = None, desc: bool = False) -> PaginationOutput[BatchReturn]:
    """Return a page of (up to) 100 batches with their job counts.

    :param page: 1-based page number
    :param after: if given, only batches created after this time are counted/returned
    :param desc: return highest batch IDs (newest) first when true
    """
    batch_count_cache = cache["batch_count"]
    cache_key = after
    async with async_session() as session, session.begin():
        # The total count is the expensive part, so it is cached per `after`.
        if (batch_count := await batch_count_cache.get(cache_key)) is None:
            stmt = select(sqlalchemy.func.count(Batch.id))
            if after:
                stmt = stmt.where(Batch.created_at > after)
            batch_count = await session.scalar(stmt)
            await batch_count_cache.set(cache_key, batch_count)
        # selectinload avoids the row multiplication (and unique() requirement)
        # of joinedload on a collection combined with LIMIT/OFFSET.
        stmt = select(Batch).order_by(Batch.id.desc() if desc else Batch.id.asc()).offset((page - 1) * 100).limit(100).options(sqlalchemy.orm.selectinload(Batch.jobs))
        if after:
            stmt = stmt.where(Batch.created_at > after)
        result = await session.scalars(stmt)
        batches = result.all()
        # Fill in `repeat_url` like the single-batch endpoint does (it was
        # always left at None here), with one query for the whole page.
        rows = await session.execute(select(RepeatURL.batch_id, RepeatURL.id).where(RepeatURL.batch_id.in_([b.id for b in batches])))
        repeat_ids = dict(rows.all())
        return PaginationOutput(
            data=[BatchReturn(id=batch.id, created_at=batch.created_at, repeat_url=repeat_ids.get(batch.id), jobs=len(batch.jobs)) for batch in batches],
            pagination=PaginationInfo(
                current_page=page,
                # Ceiling division fixes the extra empty page the old
                # `// 100 + 1` reported on exact multiples of 100.
                total_pages=max(1, (batch_count + 99) // 100),
                items=batch_count
            )
        )
559+
560+
405561
class URLInfoBody(BaseModel):
406562
url: str
407563

@@ -413,7 +569,7 @@ class URLReturn(BaseModel):
413569
@app.post("/url")
414570
async def get_url_info(body: URLInfoBody) -> URLReturn:
415571
async with async_session() as session, session.begin():
416-
stmt = select(URL).where(URL.url == body.url).limit(1)
572+
stmt = select(URL).where(URL.url == body.url).limit(1).options(sqlalchemy.orm.joinedload(URL.jobs)).options(sqlalchemy.orm.joinedload(URL.jobs, Job.batches))
417573
url = await session.scalar(stmt)
418574
if url is None:
419575
raise HTTPException(status_code=404, detail="URL not found")

0 commit comments

Comments
 (0)