Commit 50163a3

Remove caching
1 parent ec50851 commit 50163a3

1 file changed: +19 −86 lines

src/main.py: 19 additions & 86 deletions
@@ -3,7 +3,7 @@
 import datetime
 from os import environ
 import re
-from typing import Annotated, Awaitable, Generic, TypeVar, TypedDict
+from typing import Annotated, Awaitable, Generic, TypeVar
 from traceback import print_exc
 
 from pydantic import BaseModel
@@ -15,63 +15,9 @@
 from fastapi import FastAPI, HTTPException, Path, Query
 from aiohttp import ClientSession
 
-KeyT = TypeVar("KeyT")
-ValueT = TypeVar("ValueT")
-
-# LRUCache whose sections are accessed asynchronously
-class LRUCache(Generic[KeyT, ValueT]):
-    def __init__(self, max_size: int = 1024):
-        self.max_size = max_size
-        self._cache = {}
-        self._lru = []
-        self._lock = asyncio.Lock()
-
-    async def get(self, key: KeyT) -> ValueT | None:
-        async with self._lock:
-            if key in self._cache:
-                self._lru.remove(key)
-                self._lru.append(key)
-                return self._cache[key]
-            else:
-                return None
-
-    async def set(self, key: KeyT, value: ValueT) -> None:
-        async with self._lock:
-            if key in self._cache:
-                self._lru.remove(key)
-            elif len(self._cache) >= self.max_size:
-                del self._cache[self._lru.pop(0)]
-            self._cache[key] = value
-            self._lru.append(key)
-
-    async def delete(self, key: KeyT) -> None:
-        async with self._lock:
-            if key in self._cache:
-                del self._cache[key]
-                self._lru.remove(key)
-
-    async def prune(self, keep: int | None = 128) -> None:
-        async with self._lock:
-            if keep is None:
-                self._cache = {}
-                self._lru = []
-            else:
-                self._lru = self._lru[-keep:]
-                self._cache = {key: self._cache[key] for key in self._lru}
-
-class LRUCacheConsumers(TypedDict):
-    batch_job_count: LRUCache[tuple[int, datetime.datetime | None], int]  # Key is (batch_id, after)
-    job_count: LRUCache[datetime.datetime | None, int]  # Key is after
-    batch_count: LRUCache[datetime.datetime | None, int]  # Key is after
-    url_count: LRUCache[datetime.datetime | None, int]  # Key is after
-    repeat_url_count: LRUCache[datetime.datetime | None, int]  # Key is after
-
-
 engine: sqlalchemy.ext.asyncio.AsyncEngine = None
 async_session: sqlalchemy.ext.asyncio.async_sessionmaker[sqlalchemy.ext.asyncio.AsyncSession] = None
 client_session: ClientSession = None
-cache: LRUCacheConsumers = None
-
 class Base(sqlalchemy.orm.MappedAsDataclass, sqlalchemy.ext.asyncio.AsyncAttrs, sqlalchemy.orm.DeclarativeBase):
     pass
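For context on what was removed: the deleted LRUCache is a lock-guarded dict plus a recency list. A minimal sketch of its observable behavior, assuming the class deleted above is in scope (illustrative only, not part of this commit):

    import asyncio

    async def demo() -> None:
        cache: LRUCache[str, int] = LRUCache(max_size=2)
        await cache.set("a", 1)
        await cache.set("b", 2)
        await cache.get("a")     # refreshes "a"; "b" becomes least recently used
        await cache.set("c", 3)  # cache is at max_size, so the oldest entry ("b") is evicted
        assert await cache.get("b") is None
        assert await cache.get("a") == 1

    asyncio.run(demo())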

@@ -226,7 +172,6 @@ async def lifespan(_: FastAPI):
         await conn.run_sync(Base.metadata.create_all)
     workers.append(asyncio.create_task(exception_logger(url_worker(), name="url_worker")))
     workers.append(asyncio.create_task(exception_logger(repeat_url_worker(), name="repeat_url_worker")))
-    cache = {key: LRUCache() for key in LRUCacheConsumers.__required_keys__}
     try:
         yield
     finally:
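The deleted line in lifespan built one cache per field of the LRUCacheConsumers TypedDict by iterating __required_keys__. A standalone sketch of that idiom (the Counters TypedDict here is hypothetical, not from this file):

    from typing import TypedDict

    class Counters(TypedDict):
        job_count: int
        batch_count: int

    # __required_keys__ is a frozenset of the required field names;
    # the removed code iterated it to create one LRUCache per key.
    assert Counters.__required_keys__ == frozenset({"job_count", "batch_count"})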
@@ -487,19 +432,15 @@ class PaginationOutput(BaseModel, Generic[ModelT]):
 
 @app.get("/jobs")
 async def get_jobs(page: Page = 1, after: datetime.datetime | None = None, desc: bool = False) -> PaginationOutput[JobReturn]:
-    job_count_cache = cache["job_count"]
-    cache_key = after
     async with async_session() as session, session.begin():
-        if (job_count := await job_count_cache.get(cache_key)) is None:
-            stmt = select(sqlalchemy.func.count(Job.id))
-            if after:
-                stmt = stmt.where(Job.created_at > after)
-            job_count = await session.scalar(stmt)
-            await job_count_cache.set(cache_key, job_count)
-        stmt = select(Job).order_by(Job.id.desc() if desc else Job.id.asc()).offset((page - 1) * 100).limit(100).options(sqlalchemy.orm.joinedload(Job.batches))
+        stmt = select(sqlalchemy.func.count(Job.id))
         if after:
             stmt = stmt.where(Job.created_at > after)
-        result = await session.scalars(stmt)
+        job_count = await session.scalar(stmt)
+        stmt2 = select(Job).order_by(Job.id.desc() if desc else Job.id.asc()).offset((page - 1) * 100).limit(100).options(sqlalchemy.orm.joinedload(Job.batches))
+        if after:
+            stmt2 = stmt2.where(Job.created_at > after)
+        result = await session.scalars(stmt2)
     return PaginationOutput(
         data=[JobReturn.from_job(job) for job in result.unique().all()],
         pagination=PaginationInfo(
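The paging arithmetic itself is unchanged by this commit: a fixed page size of 100, with page N mapping to OFFSET (N - 1) * 100. A small sketch of that mapping (PAGE_SIZE and page_bounds are illustrative names, not from this file):

    PAGE_SIZE = 100

    def page_bounds(page: int) -> tuple[int, int]:
        """Return (offset, limit) for a 1-indexed page, as used by .offset()/.limit() above."""
        return (page - 1) * PAGE_SIZE, PAGE_SIZE

    assert page_bounds(1) == (0, 100)    # first page starts at row 0
    assert page_bounds(3) == (200, 100)  # page 3 skips the first 200 rows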
@@ -511,19 +452,15 @@ async def get_jobs(page: Page = 1, after: datetime.datetime | None = None, desc:
 
 @app.get("/batch/{batch_id}/jobs")
 async def get_batch_jobs(batch_id: Annotated[int, Path(title="Batch ID", description="The ID of the batch you want info on", ge=1)], page: Page = 1, after: datetime.datetime | None = None, desc: bool = False) -> PaginationOutput[JobReturn]:
-    job_count_cache = cache["batch_job_count"]
-    cache_key = (batch_id, after)
     async with async_session() as session, session.begin():
-        if (job_count := await job_count_cache.get(cache_key)) is None:
-            stmt = select(sqlalchemy.func.count(Job.id)).join(Batch.jobs).where(Batch.id == batch_id)
-            if after:
-                stmt = stmt.where(Job.created_at > after)
-            job_count = await session.scalar(stmt)
-            await job_count_cache.set(cache_key, job_count)
-        stmt = select(Job).join(Batch.jobs).where(Batch.id == batch_id).order_by(Job.id.desc() if desc else Job.id.asc()).offset((page - 1) * 100).limit(100).options(sqlalchemy.orm.joinedload(Job.batches))
+        stmt = select(sqlalchemy.func.count(Job.id)).join(Batch.jobs).where(Batch.id == batch_id)
         if after:
             stmt = stmt.where(Job.created_at > after)
-        result = await session.scalars(stmt)
+        job_count = await session.scalar(stmt)
+        stmt2 = select(Job).join(Batch.jobs).where(Batch.id == batch_id).order_by(Job.id.desc() if desc else Job.id.asc()).offset((page - 1) * 100).limit(100).options(sqlalchemy.orm.joinedload(Job.batches))
+        if after:
+            stmt2 = stmt2.where(Job.created_at > after)
+        result = await session.scalars(stmt2)
     return PaginationOutput(
         data=[JobReturn.from_job(job) for job in result.unique().all()],
         pagination=PaginationInfo(
@@ -535,19 +472,15 @@ async def get_batch_jobs(batch_id: Annotated[int, Path(title="Batch ID", descrip
535472

536473
@app.get("/batches")
537474
async def get_batches(page: Page = 1, after: datetime.datetime | None = None, desc: bool = False) -> PaginationOutput[BatchReturn]:
538-
batch_count_cache = cache["batch_count"]
539-
cache_key = after
540475
async with async_session() as session, session.begin():
541-
if (batch_count := await batch_count_cache.get(cache_key)) is None:
542-
stmt = select(sqlalchemy.func.count(Batch.id))
543-
if after:
544-
stmt = stmt.where(Batch.created_at > after)
545-
batch_count = await session.scalar(stmt)
546-
await batch_count_cache.set(cache_key, batch_count)
547-
stmt = select(Batch).order_by(Batch.id.desc() if desc else Batch.id.asc()).offset((page - 1) * 100).limit(100).options(sqlalchemy.orm.joinedload(Batch.jobs))
476+
stmt = select(sqlalchemy.func.count(Batch.id))
548477
if after:
549478
stmt = stmt.where(Batch.created_at > after)
550-
result = await session.scalars(stmt)
479+
batch_count = await session.scalar(stmt)
480+
stmt2 = select(Batch).order_by(Batch.id.desc() if desc else Batch.id.asc()).offset((page - 1) * 100).limit(100).options(sqlalchemy.orm.joinedload(Batch.jobs))
481+
if after:
482+
stmt2 = stmt2.where(Batch.created_at > after)
483+
result = await session.scalars(stmt2)
551484
return PaginationOutput(
552485
data=[BatchReturn(id=batch.id, created_at=batch.created_at, jobs=len(batch.jobs)) for batch in result.unique().all()],
553486
pagination=PaginationInfo(
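One pattern these endpoints keep throughout: each listing query eager-loads a collection with joinedload, which is why every result passes through .unique() before .all(). A self-contained sketch of why that call is needed, using hypothetical Parent/Child models and a sync in-memory SQLite engine rather than this service's async setup:

    import sqlalchemy
    from sqlalchemy import ForeignKey, select
    from sqlalchemy.orm import DeclarativeBase, Mapped, Session, joinedload, mapped_column, relationship

    class DemoBase(DeclarativeBase):
        pass

    class Parent(DemoBase):
        __tablename__ = "parent"
        id: Mapped[int] = mapped_column(primary_key=True)
        children: Mapped[list["Child"]] = relationship()

    class Child(DemoBase):
        __tablename__ = "child"
        id: Mapped[int] = mapped_column(primary_key=True)
        parent_id: Mapped[int] = mapped_column(ForeignKey("parent.id"))

    engine = sqlalchemy.create_engine("sqlite://")
    DemoBase.metadata.create_all(engine)
    with Session(engine) as session:
        session.add(Parent(id=1, children=[Child(id=1), Child(id=2)]))
        session.commit()
        rows = session.scalars(select(Parent).options(joinedload(Parent.children)))
        # The JOIN yields one row per child; .unique() collapses the duplicated parents.
        parents = rows.unique().all()
        assert len(parents) == 1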

0 commit comments