Skip to content

Commit e032002

Browse files
committed
Fix migration
1 parent dba9afd commit e032002

File tree

2 files changed

+48
-2
lines changed

2 files changed

+48
-2
lines changed
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
"""Add batch locking
2+
3+
Revision ID: 90db4a933a16
4+
Revises: 9a098a265cbd
5+
Create Date: 2024-01-18 14:21:22.487177
6+
7+
"""
8+
from typing import Sequence, Union
9+
10+
from alembic import op
11+
import sqlalchemy as sa
12+
13+
14+
# revision identifiers, used by Alembic.
15+
revision: str = "90db4a933a16"
16+
down_revision: Union[str, None] = "9a098a265cbd"
17+
branch_labels: Union[str, Sequence[str], None] = None
18+
depends_on: Union[str, Sequence[str], None] = None
19+
20+
21+
def upgrade() -> None:
22+
# ### commands auto generated by Alembic - please adjust! ###
23+
op.add_column(
24+
"batches", sa.Column("locked", sa.DateTime(timezone=True), nullable=True)
25+
)
26+
op.create_index(op.f("ix_batches_locked"), "batches", ["locked"], unique=False)
27+
# ### end Alembic commands ###
28+
29+
30+
def downgrade() -> None:
31+
# ### commands auto generated by Alembic - please adjust! ###
32+
op.drop_index(op.f("ix_batches_locked"), table_name="batches")
33+
op.drop_column("batches", "locked")
34+
# ### end Alembic commands ###

src/models.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,8 @@ class Batch(Base):
9090
init=False,
9191
index=True,
9292
)
93-
locked: Mapped[bool] = mapped_column(
94-
default=False
93+
locked: Mapped[datetime.datetime | None] = mapped_column(
94+
sqlalchemy.DateTime(timezone=True), default=None, nullable=True, index=True
9595
) # Indicates that a batch is locked (no more jobs can be added to it)
9696

9797
jobs: Mapped[list["Job"]] = sqlalchemy.orm.relationship(
@@ -188,6 +188,12 @@ class BatchJobs(Base):
188188
sqlalchemy.UniqueConstraint("batch_id", "job_id", name="_batch_job_uc"),
189189
)
190190

191+
@sqlalchemy.orm.validates("batch")
192+
def validate_not_locked_batch(self, key: str, batch: Batch) -> Batch:
193+
if batch.locked:
194+
raise ValueError("Batch is locked")
195+
return batch
196+
191197

192198
class Job(Base):
193199
__tablename__ = "jobs"
@@ -228,3 +234,9 @@ class Job(Base):
228234
failed: Mapped[datetime.datetime | None] = mapped_column(
229235
sqlalchemy.DateTime(timezone=True), default=None, nullable=True, index=True
230236
) # If a job has failed, this is the time it failed at
237+
238+
@sqlalchemy.orm.validates("batches")
239+
def validate_not_locked_batch(self, key: str, batch: Batch) -> Batch:
240+
if batch.locked:
241+
raise ValueError("Batch is locked")
242+
return batch

0 commit comments

Comments
 (0)