Skip to content

Commit 372cba5

Browse files
committed
Add some new features
1 parent 50163a3 commit 372cba5

File tree

1 file changed

+136
-23
lines changed

1 file changed

+136
-23
lines changed

src/main.py

Lines changed: 136 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import datetime
44
from os import environ
55
import re
6-
from typing import Annotated, Awaitable, Generic, TypeVar
6+
from typing import Annotated, Awaitable, Generic, Literal, TypeVar, TypedDict, overload
77
from traceback import print_exc
88

99
from pydantic import BaseModel
@@ -12,7 +12,7 @@
1212
import sqlalchemy.ext.asyncio
1313
from sqlalchemy import select, update
1414
from sqlalchemy.orm import Mapped, mapped_column
15-
from fastapi import FastAPI, HTTPException, Path, Query
15+
from fastapi import Depends, FastAPI, HTTPException, Path, Query
1616
from aiohttp import ClientSession
1717

1818
engine: sqlalchemy.ext.asyncio.AsyncEngine = None
@@ -122,7 +122,7 @@ async def url_worker():
122122
saved_dt = datetime.datetime.strptime(match.group(1), "%Y%m%d%H%M%S").replace(tzinfo=datetime.timezone.utc)
123123
async with session.begin():
124124
await session.execute(update(URL).where(URL.id == next_job.url.id).values(last_seen=saved_dt))
125-
await session.execute(update(Job).where(Job.id == next_job.id).values(completed=saved_dt))
125+
await session.execute(update(Job).where(Job.id == next_job.id).values(completed=saved_dt, failed=None, delayed_until=None))
126126
break
127127
except Exception:
128128
pass
@@ -133,7 +133,7 @@ async def url_worker():
133133
print(f"Retrying job id={next_job.id} for the {next_job.retry + 1} time.")
134134
await session.execute(update(Job).where(Job.id == next_job.id).values(retry=next_job.retry + 1, delayed_until=curtime + datetime.timedelta(minutes=45)))
135135
else:
136-
await session.execute(update(Job).where(Job.id == next_job.id).values(failed=curtime))
136+
await session.execute(update(Job).where(Job.id == next_job.id).values(failed=curtime, delayed_until=None))
137137

138138
async def repeat_url_worker():
139139
batch = None
@@ -430,48 +430,116 @@ class PaginationOutput(BaseModel, Generic[ModelT]):
430430

431431
Page = Annotated[int, Query(title="Page", description="The page of results you want", ge=1, le=100)]
432432

433+
class PaginationDefaultQueryArgs(TypedDict):
    """Standard pagination query parameters shared by the list endpoints."""
    page: Page
    after: datetime.datetime | None
    desc: bool


async def pagination_default_query_args(page: Page = 1, after: datetime.datetime | None = None, desc: bool = False) -> PaginationDefaultQueryArgs:
    """FastAPI dependency: collect the common pagination query args into one dict."""
    return PaginationDefaultQueryArgs(page=page, after=after, desc=desc)


# Annotated alias so endpoints can declare a single dependency parameter.
PaginationQueryArgs = Annotated[PaginationDefaultQueryArgs, Depends(pagination_default_query_args)]
446+
447+
class JobPaginationDefaultQueryArgs(PaginationDefaultQueryArgs):
    """Pagination args extended with job-status and retry-count filters."""
    not_started: bool
    completed: bool
    delayed: bool
    failed: bool
    retries_less_than: Literal[1, 2, 3, 4] | None
    retries_greater_than: Literal[0, 1, 2, 3] | None
    retries_equal_to: Literal[0, 1, 2, 3, 4] | None


async def job_pagination_default_query_args(
    page: Page = 1,
    after: datetime.datetime | None = None,
    desc: bool = False,
    not_started: bool = False,
    completed: bool = False,
    delayed: bool = False,
    failed: bool = False,
    retries_less_than: Literal[1, 2, 3, 4] | None = None,
    retries_greater_than: Literal[0, 1, 2, 3] | None = None,
    retries_equal_to: Literal[0, 1, 2, 3, 4] | None = None,
) -> JobPaginationDefaultQueryArgs:
    """FastAPI dependency: collect pagination plus job-filter query args into one dict."""
    return JobPaginationDefaultQueryArgs(
        page=page,
        after=after,
        desc=desc,
        not_started=not_started,
        completed=completed,
        delayed=delayed,
        failed=failed,
        retries_less_than=retries_less_than,
        retries_greater_than=retries_greater_than,
        retries_equal_to=retries_equal_to,
    )


# Annotated alias so endpoints can declare a single dependency parameter.
JobPaginationQueryArgs = Annotated[JobPaginationDefaultQueryArgs, Depends(job_pagination_default_query_args)]
471+
472+
@overload
def apply_job_filtering(query_params: JobPaginationDefaultQueryArgs, is_count_query: Literal[True], /) -> sqlalchemy.Select[tuple[int]]: ...
@overload
def apply_job_filtering(query_params: JobPaginationDefaultQueryArgs, is_count_query: Literal[False], /) -> sqlalchemy.Select[tuple[Job]]: ...
def apply_job_filtering(query_params: JobPaginationDefaultQueryArgs, is_count_query: bool = False, /) -> sqlalchemy.Select:
    """Build a SELECT over Job with the status, retry, and `after` filters applied.

    When is_count_query is True the statement counts matching rows; otherwise
    it selects Job rows with ordering, a 100-row page limit, and an offset
    derived from query_params["page"].

    Raises:
        HTTPException: 400 when more than one of the retries_* filters is given.
    """
    in_statement = select(sqlalchemy.func.count(Job.id) if is_count_query else Job)
    # A job matching ANY selected status should be returned, so the selected
    # predicates are OR-ed together. (Chaining .where() would AND them, which
    # could never match e.g. completed=True together with failed=True — and
    # would contradict the all-4-selected shortcut below, which returns all
    # jobs.)
    status_filters = []
    if query_params["not_started"]:
        status_filters.append((Job.completed == None) & (Job.failed == None) & (Job.delayed_until == None))
    if query_params["completed"]:
        status_filters.append(Job.completed != None)
    if query_params["delayed"]:
        status_filters.append(Job.delayed_until != None)
    if query_params["failed"]:
        status_filters.append(Job.failed != None)
    # If all 4 are given (or none), no status restriction is needed at all.
    if 0 < len(status_filters) < 4:
        in_statement = in_statement.where(sqlalchemy.or_(*status_filters))
    unset_retry_params = [query_params["retries_less_than"], query_params["retries_greater_than"], query_params["retries_equal_to"]].count(None)
    if unset_retry_params < 2:
        # Two or more retry filters supplied — they are mutually exclusive.
        # (Previously this raised when only one was supplied, making the
        # retry filters unusable and the branch below dead code.)
        raise HTTPException(status_code=400, detail="You must provide only one of retries_less_than, retries_greater_than, or retries_equal_to")
    elif unset_retry_params == 2:
        # Exactly one retry filter supplied.
        if query_params["retries_less_than"] is not None:
            in_statement = in_statement.where(Job.retry < query_params["retries_less_than"])
        elif query_params["retries_greater_than"] is not None:
            in_statement = in_statement.where(Job.retry > query_params["retries_greater_than"])
        else:
            in_statement = in_statement.where(Job.retry == query_params["retries_equal_to"])
    if query_params["after"]:
        in_statement = in_statement.where(Job.created_at > query_params["after"])
    if not is_count_query:
        # Counting must ignore pagination, so limit/order/offset only apply here.
        in_statement = in_statement.limit(100).order_by(Job.id.desc() if query_params["desc"] else Job.id.asc()).offset((query_params["page"] - 1) * 100)
    return in_statement
503+
504+
505+
506+
433507
@app.get("/jobs")
async def get_jobs(query_params: JobPaginationQueryArgs) -> PaginationOutput[JobReturn]:
    """List jobs filtered and paginated per query_params (100 per page)."""
    async with async_session() as session, session.begin():
        job_count = await session.scalar(apply_job_filtering(query_params, True))
        # joinedload pulls the related batches in the same query; unique() below
        # collapses the duplicated Job rows the join produces.
        stmt = apply_job_filtering(query_params, False).options(sqlalchemy.orm.joinedload(Job.batches))
        result = await session.scalars(stmt)
        return PaginationOutput(
            data=[JobReturn.from_job(job) for job in result.unique().all()],
            pagination=PaginationInfo(
                current_page=query_params["page"],
                # Ceiling division, minimum 1 page. The previous
                # `job_count // 100 + 1` reported an extra empty page whenever
                # the count was an exact multiple of 100.
                total_pages=max(1, (job_count + 99) // 100),
                items=job_count
            )
        )
452521

453522
@app.get("/batch/{batch_id}/jobs")
async def get_batch_jobs(batch_id: Annotated[int, Path(title="Batch ID", description="The ID of the batch you want info on", ge=1)], query_params: JobPaginationQueryArgs) -> PaginationOutput[JobReturn]:
    """List the jobs belonging to one batch, filtered and paginated per query_params."""
    async with async_session() as session, session.begin():
        # Restrict both the count and the data query to this batch's jobs.
        stmt = apply_job_filtering(query_params, True).join(Batch.jobs).where(Batch.id == batch_id)
        job_count = await session.scalar(stmt)
        stmt2 = apply_job_filtering(query_params, False).join(Batch.jobs).where(Batch.id == batch_id).options(sqlalchemy.orm.joinedload(Job.batches))
        result = await session.scalars(stmt2)
        return PaginationOutput(
            data=[JobReturn.from_job(job) for job in result.unique().all()],
            pagination=PaginationInfo(
                current_page=query_params["page"],
                # Ceiling division, minimum 1 page. The previous
                # `job_count // 100 + 1` reported an extra empty page whenever
                # the count was an exact multiple of 100.
                total_pages=max(1, (job_count + 99) // 100),
                items=job_count
            )
        )
472537

473538
@app.get("/batches")
474-
async def get_batches(page: Page = 1, after: datetime.datetime | None = None, desc: bool = False) -> PaginationOutput[BatchReturn]:
539+
async def get_batches(query_params: PaginationQueryArgs) -> PaginationOutput[BatchReturn]:
540+
after = query_params["after"]
541+
page = query_params["page"]
542+
desc = query_params["desc"]
475543
async with async_session() as session, session.begin():
476544
stmt = select(sqlalchemy.func.count(Batch.id))
477545
if after:
@@ -490,6 +558,51 @@ async def get_batches(page: Page = 1, after: datetime.datetime | None = None, de
490558
)
491559
)
492560

561+
@app.get("/repeat_urls")
async def get_repeat_urls(query_params: PaginationQueryArgs) -> PaginationOutput[RepeatURL]:
    """List repeating URLs, paginated (100 per page), optionally created after a cutoff."""
    after = query_params["after"]
    page = query_params["page"]
    desc = query_params["desc"]
    async with async_session() as session, session.begin():
        stmt = select(sqlalchemy.func.count(RepeatURL.id))
        if after:
            stmt = stmt.where(RepeatURL.created_at > after)
        repeat_url_count = await session.scalar(stmt)
        stmt2 = select(RepeatURL).order_by(RepeatURL.id.desc() if desc else RepeatURL.id.asc()).offset((page - 1) * 100).limit(100)
        if after:
            stmt2 = stmt2.where(RepeatURL.created_at > after)
        result = await session.scalars(stmt2)
        return PaginationOutput(
            data=result.all(),
            pagination=PaginationInfo(
                current_page=page,
                # Ceiling division, minimum 1 page. The previous
                # `count // 100 + 1` reported an extra empty page whenever the
                # count was an exact multiple of 100.
                total_pages=max(1, (repeat_url_count + 99) // 100),
                items=repeat_url_count
            )
        )
583+
584+
@app.get("/urls")
async def get_urls(query_params: PaginationQueryArgs, unique: bool = True) -> PaginationOutput[URL]:
    """List URLs, paginated (100 per page), optionally first seen after a cutoff.

    NOTE(review): `unique` is accepted but never referenced when building the
    query — presumably it was meant to toggle de-duplication of URL rows;
    confirm the intended semantics before wiring it in or removing it.
    """
    after = query_params["after"]
    page = query_params["page"]
    desc = query_params["desc"]
    async with async_session() as session, session.begin():
        stmt = select(sqlalchemy.func.count(URL.id))
        if after:
            stmt = stmt.where(URL.first_seen > after)
        url_count = await session.scalar(stmt)
        stmt2 = select(URL).order_by(URL.id.desc() if desc else URL.id.asc()).offset((page - 1) * 100).limit(100)
        if after:
            stmt2 = stmt2.where(URL.first_seen > after)
        result = await session.scalars(stmt2)
        return PaginationOutput(
            data=result.unique().all(),
            pagination=PaginationInfo(
                current_page=page,
                # Ceiling division, minimum 1 page. The previous
                # `count // 100 + 1` reported an extra empty page whenever the
                # count was an exact multiple of 100.
                total_pages=max(1, (url_count + 99) // 100),
                items=url_count
            )
        )
493606

494607
class URLInfoBody(BaseModel):
495608
url: str

0 commit comments

Comments (0)