Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions tasktiger/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from collections import defaultdict
import importlib
import logging
from json import JSONDecoder, JSONEncoder
import redis
import structlog

Expand Down Expand Up @@ -159,6 +160,13 @@ def __init__(self, connection=None, config=None, setup_structlog=False):

# If non-empty, a worker excludes the given queues from processing.
'EXCLUDE_QUEUES': [],

            # Serializer / Deserializer to use for serializing/deserializing tasks

'JSON_ENCODER': JSONEncoder,

'JSON_DECODER': JSONDecoder

}
if config:
self.config.update(config)
Expand Down
4 changes: 2 additions & 2 deletions tasktiger/_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def gen_id():
"""
return binascii.b2a_hex(os.urandom(32)).decode('utf8')

def gen_unique_id(serialized_name, args, kwargs):
def gen_unique_id(serialized_name, args, kwargs, cls=None):
"""
Generates and returns a hex-encoded 256-bit ID for the given task name and
args. Used to generate IDs for unique tasks or for task locks.
Expand All @@ -62,7 +62,7 @@ def gen_unique_id(serialized_name, args, kwargs):
'func': serialized_name,
'args': args,
'kwargs': kwargs,
}, sort_keys=True).encode('utf8')).hexdigest()
}, sort_keys=True, cls=cls).encode('utf8')).hexdigest()

def serialize_func_name(func):
"""
Expand Down
19 changes: 12 additions & 7 deletions tasktiger/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ def __init__(self, tiger, func=None, args=None, kwargs=None, queue=None,
self._state = _state
self._ts = _ts
self._executions = _executions or []
self.json_encoder = self.tiger.config.get('JSON_ENCODER')
self.json_decoder = self.tiger.config.get('JSON_DECODER')

# Internal initialization based on raw data.
if _data is not None:
Expand Down Expand Up @@ -57,7 +59,7 @@ def __init__(self, tiger, func=None, args=None, kwargs=None, queue=None,
retry_method = getattr(func, '_task_retry_method', None)

if unique:
task_id = gen_unique_id(serialized_name, args, kwargs)
task_id = gen_unique_id(serialized_name, args, kwargs, cls=self.json_encoder)
else:
task_id = gen_id()

Expand Down Expand Up @@ -280,7 +282,7 @@ def delay(self, when=None):

# When using ALWAYS_EAGER, make sure we have serialized the task to
# ensure there are no serialization errors.
serialized_task = json.dumps(self._data)
serialized_task = json.dumps(self._data, cls=self.json_encoder)

if tiger.config['ALWAYS_EAGER'] and state == QUEUED:
return self.execute()
Expand Down Expand Up @@ -341,8 +343,9 @@ def from_id(self, tiger, queue, state, task_id, load_executions=0):
serialized_executions = []
# XXX: No timestamp for now
if serialized_data:
data = json.loads(serialized_data)
executions = [json.loads(e) for e in serialized_executions if e]
json_decoder = tiger.config.get('JSON_DECODER')
data = json.loads(serialized_data, cls=json_decoder)
executions = [json.loads(e, cls=json_decoder) for e in serialized_executions if e]
return Task(tiger, queue=queue, _data=data, _state=state,
_executions=executions)
else:
Expand Down Expand Up @@ -370,6 +373,8 @@ def tasks_from_queue(self, tiger, queue, state, skip=0, limit=1000,

tasks = []

json_decoder = tiger.config.get('JSON_DECODER')

if items:
tss = [datetime.datetime.utcfromtimestamp(item[1]) for item in items]
if load_executions:
Expand All @@ -380,8 +385,8 @@ def tasks_from_queue(self, tiger, queue, state, skip=0, limit=1000,
results = pipeline.execute()

for serialized_data, serialized_executions, ts in zip(results[0], results[1:], tss):
data = json.loads(serialized_data)
executions = [json.loads(e) for e in serialized_executions if e]
data = json.loads(serialized_data, cls=json_decoder)
executions = [json.loads(e, cls=json_decoder) for e in serialized_executions if e]

task = Task(tiger, queue=queue, _data=data, _state=state,
_ts=ts, _executions=executions)
Expand All @@ -390,7 +395,7 @@ def tasks_from_queue(self, tiger, queue, state, skip=0, limit=1000,
else:
data = tiger.connection.mget([tiger._key('task', item[0]) for item in items])
for serialized_data, ts in zip(data, tss):
data = json.loads(serialized_data)
data = json.loads(serialized_data, cls=json_decoder)
task = Task(tiger, queue=queue, _data=data, _state=state,
_ts=ts)
tasks.append(task)
Expand Down
10 changes: 7 additions & 3 deletions tasktiger/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ def __init__(self, tiger, queues=None, exclude_queues=None,
self._did_work = True
self._last_task_check = 0
self.stats_thread = None
self.json_encoder = tiger.config.get('JSON_ENCODER')
self.json_decoder = tiger.config.get('JSON_DECODER')

if queues:
self.only_queues = set(queues)
Expand Down Expand Up @@ -327,7 +329,7 @@ def _execute_forked(self, tasks, log):
''.join(traceback.format_exception(*exc_info))
execution['success'] = success
execution['host'] = socket.gethostname()
serialized_execution = json.dumps(execution)
serialized_execution = json.dumps(execution, cls=self.json_encoder)
for task in tasks:
self.connection.rpush(self._key('task', task.id, 'executions'),
serialized_execution)
Expand Down Expand Up @@ -544,7 +546,7 @@ def _process_queue_tasks(self, queue, queue_lock, task_ids, now, log):
tasks = []
for task_id, serialized_task in zip(task_ids, serialized_tasks):
if serialized_task:
task_data = json.loads(serialized_task)
task_data = json.loads(serialized_task, cls=self.json_decoder)
else:
# In the rare case where we don't find the task which is
# queued (see ReliabilityTestCase.test_task_disappears),
Expand Down Expand Up @@ -671,12 +673,14 @@ def _execute_task_group(self, queue, tasks, all_task_ids, queue_lock):
task.serialized_func,
None,
{key: kwargs.get(key) for key in task.lock_key},
cls=self.json_encoder
)
else:
lock_id = gen_unique_id(
task.serialized_func,
task.args,
task.kwargs,
cls=self.json_encoder
)

if lock_id not in lock_ids:
Expand Down Expand Up @@ -739,7 +743,7 @@ def _mark_done():
self._key('task', task.id, 'executions'), -1)

if execution:
execution = json.loads(execution)
execution = json.loads(execution, cls=self.json_decoder)

if execution and execution.get('retry'):
if 'retry_method' in execution:
Expand Down