diff --git a/tasktiger/redis_scripts.py b/tasktiger/redis_scripts.py index 64f8c31..5b93dfa 100644 --- a/tasktiger/redis_scripts.py +++ b/tasktiger/redis_scripts.py @@ -260,6 +260,22 @@ assert(redis.call('zscore', KEYS[1], ARGV[1]), '') """ +# KEYS = { scheduled_zset_key, task_data_key } +# ARGV = { score, member } +# score is used both as the new ZSET score and as the new scheduled_at value. +UPDATE_SCHEDULED_TIME = """ + local score = redis.call('zscore', KEYS[1], ARGV[2]) + if score then + redis.call('zadd', KEYS[1], ARGV[1], ARGV[2]) + local data = cjson.decode(redis.call('get', KEYS[2])) + data['scheduled_at'] = tonumber(ARGV[1]) + redis.call('set', KEYS[2], cjson.encode(data)) + return 1 + else + return 0 + end +""" + # KEYS = { } # ARGV = { key_prefix, time, batch_size } GET_EXPIRED_TASKS = """ @@ -321,6 +337,8 @@ def __init__(self, redis: Redis) -> None: self._get_expired_tasks = redis.register_script(GET_EXPIRED_TASKS) + self._update_scheduled_time = redis.register_script(UPDATE_SCHEDULED_TIME) + self._move_task = self.register_script_from_file( "lua/move_task.lua", include_functions={ @@ -566,6 +584,24 @@ def get_expired_tasks( # [queue1, task1, queue2, task2] -> [(queue1, task1), (queue2, task2)] return list(zip(result[::2], result[1::2])) + def update_scheduled_time( + self, + scheduled_zset_key: str, + task_data_key: str, + score: float, + member: str, + ) -> bool: + """ + Atomically updates a task's scheduled time in the ZSET and patches + scheduled_at in the _data blob. Returns True if the task was found and + updated, False if the task was not present in the ZSET. + """ + result = self._update_scheduled_time( + keys=[scheduled_zset_key, task_data_key], + args=[score, member], + ) + return bool(result) + def move_task( self, id: str, diff --git a/tasktiger/task.py b/tasktiger/task.py index 4c43547..2343ba6 100644 --- a/tasktiger/task.py +++ b/tasktiger/task.py @@ -183,6 +183,18 @@ def time_last_queued(self) -> Optional[datetime.datetime]: else: return datetime.datetime.utcfromtimestamp(timestamp) + @property + def scheduled_at(self) -> Optional[datetime.datetime]: + """ + The timestamp (datetime) of when the task was intended to run — either + the `when` value passed to `delay()`, or the time `delay()` was called + if no `when` was given. Returns None if the task has never been queued. + """ + timestamp = self._data.get("scheduled_at") + if timestamp is None: + return None + return datetime.datetime.utcfromtimestamp(timestamp) + @property def state(self) -> str: return self._state @@ -373,6 +385,8 @@ def delay( else: state = SCHEDULED + self._data["scheduled_at"] = ts + # When using ALWAYS_EAGER, make sure we have serialized the task to # ensure there are no serialization errors. serialized_task = json.dumps(self._data) @@ -417,18 +431,20 @@ def update_scheduled_time( ts = get_timestamp(when) assert ts - pipeline = tiger.connection.pipeline() - key = tiger._key(SCHEDULED, self.queue) - tiger.scripts.zadd(key, ts, self.id, mode="xx", client=pipeline) - pipeline.zscore(key, self.id) - _, score = pipeline.execute() - if not score: + found = tiger.scripts.update_scheduled_time( + scheduled_zset_key=tiger._key(SCHEDULED, self.queue), + task_data_key=tiger._key("task", self.id), + score=ts, + member=self.id, + ) + if not found: raise TaskNotFound( 'Task {} not found in queue "{}" in state "{}".'.format( self.id, self.queue, SCHEDULED ) ) + self._data["scheduled_at"] = ts self._ts = ts def __repr__(self) -> str: diff --git a/tests/test_task.py b/tests/test_task.py index ad4c475..2b25c5c 100644 --- a/tests/test_task.py +++ b/tests/test_task.py @@ -1,4 +1,7 @@ +import datetime + import pytest +from freezefrog import FreezeTime from tasktiger import Task, TaskNotFound @@ -49,3 +52,51 @@ def some_task(): task = tiger.delay(some_task, max_stored_executions=11) assert task.max_stored_executions == 11 + + +class TestScheduledAt: + FROZEN_NOW = datetime.datetime(2024, 1, 1, 12, 0, 0) + + def test_immediate_task_scheduled_at_equals_queue_time(self, tiger): + with FreezeTime(self.FROZEN_NOW): + task = tiger.delay(simple_task) + assert task.scheduled_at == self.FROZEN_NOW + + def test_future_task_scheduled_at_equals_when(self, tiger): + future = datetime.timedelta(minutes=5) + with FreezeTime(self.FROZEN_NOW): + task = tiger.delay(simple_task, when=future) + assert task.scheduled_at == self.FROZEN_NOW + future + + def test_scheduled_at_survives_scheduled_to_queued_transition(self, tiger): + future = datetime.timedelta(minutes=5) + with FreezeTime(self.FROZEN_NOW): + task = tiger.delay(simple_task, when=future) + expected = self.FROZEN_NOW + future + + task._move(from_state="scheduled", to_state="queued") + reloaded = Task.from_id(tiger, task.queue, "queued", task.id) + assert reloaded.scheduled_at == expected + + def test_scheduled_at_persists_after_reload(self, tiger): + with FreezeTime(self.FROZEN_NOW): + task = tiger.delay(simple_task) + reloaded = Task.from_id(tiger, task.queue, "queued", task.id) + assert reloaded.scheduled_at == self.FROZEN_NOW + + def test_scheduled_at_none_for_unqueued_task(self, tiger): + task = Task(tiger, simple_task) + assert task.scheduled_at is None + + def test_update_scheduled_time_updates_scheduled_at(self, tiger): + future = datetime.timedelta(minutes=5) + later = datetime.timedelta(minutes=10) + with FreezeTime(self.FROZEN_NOW): + task = tiger.delay(simple_task, when=future) + assert task.scheduled_at == self.FROZEN_NOW + future + + new_when = self.FROZEN_NOW + later + task.update_scheduled_time(when=new_when) + + reloaded = Task.from_id(tiger, task.queue, "scheduled", task.id) + assert reloaded.scheduled_at == new_when