Skip to content

Commit 6d96607

Browse files
authored
Support for custom task runners (#175)
Allow specifying a Python class to influence task running behavior.
1 parent 6064c95 commit 6d96607

File tree

7 files changed

+223
-14
lines changed

7 files changed

+223
-14
lines changed

README.rst

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,28 @@ The following options are supported by both ``delay`` and the task decorator:
337337
For example, to retry a task 3 times (for a total of 4 executions), and wait
338338
60 seconds between executions, pass ``retry_method=fixed(60, 3)``.
339339

340+
- ``runner_class``
341+
342+
If given, a Python class can be specified to influence task running behavior.
343+
The runner class should inherit from ``tasktiger.runner.BaseRunner`` and implement
344+
the task execution behavior. The default implementation is available in
345+
``tasktiger.runner.DefaultRunner``. The following behavior can be achieved:
346+
347+
- Execute specific code before or after the task is executed (in the forked
348+
child process), or customize the way task functions are called in either
349+
single or batch processing.
350+
351+
Note that if you want to execute specific code for all tasks,
352+
you should use the ``CHILD_CONTEXT_MANAGERS`` configuration option.
353+
354+
- Control the hard timeout behavior of a task.
355+
356+
- Execute specific code in the main worker process after a task failed
357+
permanently.
358+
359+
This is an advanced feature and the interface and requirements of the runner
360+
class can change in future TaskTiger versions.
361+
340362
The following options can only be specified in the task decorator:
341363

342364
- ``batch``
@@ -408,6 +430,7 @@ Example usage:
408430
.. code:: python
409431
410432
from tasktiger.exceptions import RetryException
433+
from tasktiger.retry import exponential, fixed
411434
412435
def my_task():
413436
if not ready():

tasktiger/runner.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
from ._internal import import_attribute
2+
from .exceptions import TaskImportError
3+
from .timeouts import UnixSignalDeathPenalty
4+
5+
6+
class BaseRunner:
    """
    Foundation class for task runners.

    The ``run_*`` hooks raise ``NotImplementedError`` here and are meant to
    be overridden; ``on_permanent_error`` is a no-op by default.
    """

    def __init__(self, tiger):
        # Keep the object handed in by the caller available to every hook.
        self.tiger = tiger

    def run_single_task(self, task, hard_timeout):
        """
        Execute one task, honoring the hard timeout given in seconds.

        Invoked from within the forked process.
        """
        message = "Single tasks are not supported."
        raise NotImplementedError(message)

    def run_batch_tasks(self, tasks, hard_timeout):
        """
        Execute several tasks, honoring the hard timeout given in seconds.

        Invoked from within the forked process.
        """
        message = "Batch tasks are not supported."
        raise NotImplementedError(message)

    def run_eager_task(self, task):
        """
        Execute the task eagerly and hand back its return value.

        The task function may be a batch function.
        """
        message = "Eager tasks are not supported."
        raise NotImplementedError(message)

    def on_permanent_error(self, task, execution):
        """
        Hook invoked when the task has failed permanently.

        A permanent failure means the task's status was set to ERROR and it
        will not be retried any more.

        Invoked in the main worker process.
        """
        return None
47+
48+
49+
class DefaultRunner(BaseRunner):
    """
    Stock task runner that invokes task functions directly.
    """

    def run_single_task(self, task, hard_timeout):
        """Call the task function with its own args/kwargs under the timeout."""
        penalty = UnixSignalDeathPenalty(hard_timeout)
        with penalty:
            task.func(*task.args, **task.kwargs)

    def run_batch_tasks(self, tasks, hard_timeout):
        """Call the first task's function once with every task's parameters."""
        batch_params = []
        for item in tasks:
            batch_params.append({'args': item.args, 'kwargs': item.kwargs})
        # All tasks are passed to the function of the first task.
        batch_func = tasks[0].func
        with UnixSignalDeathPenalty(hard_timeout):
            batch_func(batch_params)

    def run_eager_task(self, task):
        """Call the task function right away and return whatever it returns."""
        target = task.func

        # Plain functions receive the arguments unpacked.
        if not getattr(target, '_task_batch', False):
            return target(*task.args, **task.kwargs)

        # Functions flagged with ``_task_batch`` get a one-element list.
        return target([{'args': task.args, 'kwargs': task.kwargs}])
72+
73+
74+
def get_runner_class(log, tasks):
    """
    Return the runner class shared by the given tasks.

    :param log: Logger exposing ``error(message, **kwargs)``.
    :param tasks: Non-empty iterable of tasks, each exposing
        ``serialized_runner_class``.
    :returns: The class imported from the tasks' serialized runner class
        path, or ``DefaultRunner`` if the tasks do not specify one.
    :raises ValueError: If the tasks do not all share the same runner class.
    :raises TaskImportError: If the runner class cannot be imported.
    :raises KeyError: If ``tasks`` is empty.
    """
    runner_class_paths = {task.serialized_runner_class for task in tasks}
    if len(runner_class_paths) > 1:
        log.error(
            "cannot mix multiple runner classes",
            # Sorted so the logged value does not depend on set ordering.
            runner_class_paths=", ".join(
                sorted(str(p) for p in runner_class_paths)
            ),
        )
        raise ValueError("Found multiple runner classes in batch task.")

    runner_class_path = runner_class_paths.pop()
    if runner_class_path:
        try:
            return import_attribute(runner_class_path)
        except TaskImportError:
            # Log the path that failed to import. The previous code referenced
            # an undefined name here, which raised NameError and masked the
            # original TaskImportError.
            log.error('could not import runner class', func=runner_class_path)
            raise
    return DefaultRunner

tasktiger/task.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
import time
66

77
from ._internal import *
8-
from .exceptions import QueueFullException, TaskNotFound
8+
from .exceptions import QueueFullException, TaskImportError, TaskNotFound
9+
from .runner import get_runner_class
910

1011
__all__ = ['Task']
1112

@@ -26,6 +27,7 @@ def __init__(
2627
retry_on=None,
2728
retry_method=None,
2829
max_queue_size=None,
30+
runner_class=None,
2931
# internal variables
3032
_data=None,
3133
_state=None,
@@ -76,6 +78,9 @@ def __init__(
7678
if max_queue_size is None:
7779
max_queue_size = getattr(func, '_task_max_queue_size', None)
7880

81+
if runner_class is None:
82+
runner_class = getattr(func, '_task_runner_class', None)
83+
7984
# normalize falsy args/kwargs to empty structures
8085
args = args or []
8186
kwargs = kwargs or {}
@@ -110,6 +115,9 @@ def __init__(
110115
]
111116
if max_queue_size:
112117
task['max_queue_size'] = max_queue_size
118+
if runner_class:
119+
serialized_runner_class = serialize_func_name(runner_class)
120+
task['runner_class'] = serialized_runner_class
113121

114122
self._data = task
115123

@@ -191,6 +199,10 @@ def func(self):
191199
self._func = import_attribute(self.serialized_func)
192200
return self._func
193201

202+
@property
def serialized_runner_class(self):
    """
    Return the value stored under ``'runner_class'`` in the task data.

    The lookup uses ``self._data.get``, so presumably this is ``None`` when
    no runner class was stored (assumes ``_data`` is a dict).
    """
    return self._data.get('runner_class')
205+
194206
@property
195207
def ts(self):
196208
"""
@@ -298,10 +310,9 @@ def execute(self):
298310
g['tiger'] = self.tiger
299311

300312
try:
301-
if is_batch_func:
302-
return func([{'args': self.args, 'kwargs': self.kwargs}])
303-
else:
304-
return func(*self.args, **self.kwargs)
313+
runner_class = get_runner_class(self.tiger.log, [self])
314+
runner = runner_class(self.tiger)
315+
return runner.run_eager_task(self)
305316
finally:
306317
g['current_task_is_batch'] = None
307318
g['current_tasks'] = None

tasktiger/tasktiger.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,7 @@ def task(
276276
schedule=None,
277277
batch=False,
278278
max_queue_size=None,
279+
runner_class=None,
279280
):
280281
"""
281282
Function decorator that defines the behavior of the function when it is
@@ -318,6 +319,8 @@ def _wrap(func):
318319
func._task_schedule = schedule
319320
if max_queue_size is not None:
320321
func._task_max_queue_size = max_queue_size
322+
if runner_class is not None:
323+
func._task_runner_class = runner_class
321324

322325
func.delay = _delay(func)
323326

@@ -389,6 +392,7 @@ def delay(
389392
retry_on=None,
390393
retry_method=None,
391394
max_queue_size=None,
395+
runner_class=None,
392396
):
393397
"""
394398
Queues a task. See README.rst for an explanation of the options.
@@ -407,6 +411,7 @@ def delay(
407411
retry=retry,
408412
retry_on=retry_on,
409413
retry_method=retry_method,
414+
runner_class=runner_class,
410415
)
411416

412417
task.delay(when=when, max_queue_size=max_queue_size)

tasktiger/worker.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,10 @@
2020
from .exceptions import RetryException, TaskNotFound
2121
from .redis_semaphore import Semaphore
2222
from .retry import *
23+
from .runner import get_runner_class
2324
from .stats import StatsThread
2425
from .task import Task
25-
from .timeouts import UnixSignalDeathPenalty, JobTimeoutException
26+
from .timeouts import JobTimeoutException
2627

2728
if sys.version_info < (3, 3):
2829
from contextlib2 import ExitStack
@@ -362,6 +363,9 @@ def _execute_forked(self, tasks, log):
362363
try:
363364
func = tasks[0].func
364365

366+
runner_class = get_runner_class(log, tasks)
367+
runner = runner_class(self.tiger)
368+
365369
is_batch_func = getattr(func, '_task_batch', False)
366370
g['tiger'] = self.tiger
367371
g['current_task_is_batch'] = is_batch_func
@@ -371,10 +375,6 @@ def _execute_forked(self, tasks, log):
371375
):
372376
if is_batch_func:
373377
# Batch process if the task supports it.
374-
params = [
375-
{'args': task.args, 'kwargs': task.kwargs}
376-
for task in tasks
377-
]
378378
task_timeouts = [
379379
task.hard_timeout
380380
for task in tasks
@@ -387,8 +387,7 @@ def _execute_forked(self, tasks, log):
387387
)
388388

389389
g['current_tasks'] = tasks
390-
with UnixSignalDeathPenalty(hard_timeout):
391-
func(params)
390+
runner.run_batch_tasks(tasks, hard_timeout)
392391

393392
else:
394393
# Process sequentially.
@@ -400,8 +399,7 @@ def _execute_forked(self, tasks, log):
400399
)
401400

402401
g['current_tasks'] = [task]
403-
with UnixSignalDeathPenalty(hard_timeout):
404-
func(*task.args, **task.kwargs)
402+
runner.run_single_task(task, hard_timeout)
405403

406404
except RetryException as exc:
407405
execution['retry'] = True
@@ -1015,6 +1013,10 @@ def _mark_done():
10151013
_mark_done()
10161014
else:
10171015
task._move(from_state=ACTIVE, to_state=state, when=when)
1016+
if state == ERROR and task.serialized_runner_class:
1017+
runner_class = get_runner_class(log, [task])
1018+
runner = runner_class(self.tiger)
1019+
runner.on_permanent_error(task, execution)
10181020

10191021
def _worker_run(self):
10201022
"""

tests/tasks.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from tasktiger import RetryException, TaskTiger
88
from tasktiger.retry import fixed
9+
from tasktiger.runner import BaseRunner, DefaultRunner
910

1011
from .config import DELAY, TEST_DB, REDIS_HOST
1112
from .utils import get_tiger
@@ -179,3 +180,38 @@ class StaticTask(object):
179180
@staticmethod
180181
def task():
181182
pass
183+
184+
185+
class MyRunnerClass(BaseRunner):
    """
    Test runner that checks its inputs and records them in Redis instead of
    calling the task function.
    """

    def run_single_task(self, task, hard_timeout):
        """Verify the inputs and store the task ID under ``task_id``."""
        assert self.tiger.config == tiger.config
        assert hard_timeout == 300
        assert task.func is simple_task

        redis_options = {
            'host': REDIS_HOST,
            'db': TEST_DB,
            'decode_responses': True,
        }
        with redis.Redis(**redis_options) as client:
            client.set('task_id', task.id)

    def run_batch_tasks(self, tasks, hard_timeout):
        """Verify the inputs and store each task's first arg in ``task_args``."""
        assert self.tiger.config == tiger.config
        assert hard_timeout == 300
        assert len(tasks) == 2

        redis_options = {
            'host': REDIS_HOST,
            'db': TEST_DB,
            'decode_responses': True,
        }
        with redis.Redis(**redis_options) as client:
            first_args = [str(item.args[0]) for item in tasks]
            client.set('task_args', ",".join(first_args))

    def run_eager_task(self, task):
        """Return a fixed sentinel value without running the task."""
        sentinel = 123
        return sentinel
208+
209+
210+
class MyErrorRunnerClass(DefaultRunner):
    """
    Test runner that records the ID of a permanently failed task in Redis.
    """

    def on_permanent_error(self, task, execution):
        """Verify the failure details and store the task ID under ``task_id``."""
        assert task.func is exception_task
        assert execution["exception_name"] == "builtins:Exception"

        redis_options = {
            'host': REDIS_HOST,
            'db': TEST_DB,
            'decode_responses': True,
        }
        with redis.Redis(**redis_options) as client:
            client.set('task_id', task.id)

tests/test_base.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@
3232
locked_task,
3333
long_task_killed,
3434
long_task_ok,
35+
MyErrorRunnerClass,
36+
MyRunnerClass,
3537
non_batch_task,
3638
retry_task,
3739
retry_task_2,
@@ -1236,3 +1238,43 @@ def fake_error(msg):
12361238
self._ensure_queues()
12371239
assert len(errors) == 1
12381240
assert "not found" in errors[0]
1241+
1242+
1243+
class TestRunnerClass(BaseTestCase):
    """Tests for tasks queued with a custom ``runner_class``."""

    def test_custom_runner_class_single_task(self):
        """A single task run via ``MyRunnerClass`` leaves its ID in Redis."""
        task = self.tiger.delay(simple_task, runner_class=MyRunnerClass)
        Worker(self.tiger).run(once=True)
        assert self.conn.get('task_id') == task.id
        # Remove the marker key before the queue check below.
        self.conn.delete('task_id')
        self._ensure_queues()

    def test_custom_runner_class_batch_task(self):
        """Two batch tasks run via ``MyRunnerClass`` are recorded as "1,2"."""
        self.tiger.delay(batch_task, args=[1], runner_class=MyRunnerClass)
        self.tiger.delay(batch_task, args=[2], runner_class=MyRunnerClass)
        Worker(self.tiger).run(once=True)
        assert self.conn.get('task_args') == "1,2"
        # Remove the marker key before the queue check below.
        self.conn.delete('task_args')
        self._ensure_queues()

    def test_mixed_runner_class_batch_task(self):
        """Ensure all tasks in a batch task must have the same runner class."""
        self.tiger.delay(batch_task, args=[1], runner_class=MyRunnerClass)
        self.tiger.delay(batch_task, args=[2])
        Worker(self.tiger).run(once=True)
        # No marker key is expected since the runner must not have run.
        assert self.conn.get('task_args') is None
        self._ensure_queues(error={'batch': 2})

    def test_permanent_error(self):
        """A failing task run via ``MyErrorRunnerClass`` leaves its ID in Redis."""
        task = self.tiger.delay(
            exception_task, runner_class=MyErrorRunnerClass
        )
        Worker(self.tiger).run(once=True)
        assert self.conn.get('task_id') == task.id
        # Remove the marker key before the queue check below.
        self.conn.delete('task_id')
        self._ensure_queues(error={'default': 1})

    def test_eager_task(self):
        """With ``ALWAYS_EAGER`` set, ``delay`` returns the runner's value."""
        self.tiger.config['ALWAYS_EAGER'] = True
        task = Task(self.tiger, simple_task, runner_class=MyRunnerClass)
        assert task.delay() == 123
        self._ensure_queues()

0 commit comments

Comments
 (0)