diff --git a/tasktiger/__init__.py b/tasktiger/__init__.py index 2c81d82a..d37400a3 100644 --- a/tasktiger/__init__.py +++ b/tasktiger/__init__.py @@ -2,6 +2,7 @@ from collections import defaultdict import importlib import logging +from json import JSONDecoder, JSONEncoder import redis import structlog @@ -159,6 +160,13 @@ def __init__(self, connection=None, config=None, setup_structlog=False): # If non-empty, a worker excludes the given queues from processing. 'EXCLUDE_QUEUES': [], + + # Serializer / Deserializer to use for serializing/deserializing tasks + + 'JSON_ENCODER': JSONEncoder, + + 'JSON_DECODER': JSONDecoder + } if config: self.config.update(config) diff --git a/tasktiger/_internal.py b/tasktiger/_internal.py index e0fc07ef..43f8b830 100644 --- a/tasktiger/_internal.py +++ b/tasktiger/_internal.py @@ -53,7 +53,7 @@ def gen_id(): """ return binascii.b2a_hex(os.urandom(32)).decode('utf8') -def gen_unique_id(serialized_name, args, kwargs): +def gen_unique_id(serialized_name, args, kwargs, cls=None): """ Generates and returns a hex-encoded 256-bit ID for the given task name and args. Used to generate IDs for unique tasks or for task locks. @@ -62,7 +62,7 @@ def gen_unique_id(serialized_name, args, kwargs): 'func': serialized_name, 'args': args, 'kwargs': kwargs, - }, sort_keys=True).encode('utf8')).hexdigest() + }, sort_keys=True, cls=cls).encode('utf8')).hexdigest() def serialize_func_name(func): """ diff --git a/tasktiger/task.py b/tasktiger/task.py index ca400e5b..9e06d8e4 100644 --- a/tasktiger/task.py +++ b/tasktiger/task.py @@ -28,6 +28,8 @@ def __init__(self, tiger, func=None, args=None, kwargs=None, queue=None, self._state = _state self._ts = _ts self._executions = _executions or [] + self.json_encoder = self.tiger.config.get('JSON_ENCODER') + self.json_decoder = self.tiger.config.get('JSON_DECODER') # Internal initialization based on raw data. 
if _data is not None: @@ -57,7 +59,7 @@ def __init__(self, tiger, func=None, args=None, kwargs=None, queue=None, retry_method = getattr(func, '_task_retry_method', None) if unique: - task_id = gen_unique_id(serialized_name, args, kwargs) + task_id = gen_unique_id(serialized_name, args, kwargs, cls=self.json_encoder) else: task_id = gen_id() @@ -280,7 +282,7 @@ def delay(self, when=None): # When using ALWAYS_EAGER, make sure we have serialized the task to # ensure there are no serialization errors. - serialized_task = json.dumps(self._data) + serialized_task = json.dumps(self._data, cls=self.json_encoder) if tiger.config['ALWAYS_EAGER'] and state == QUEUED: return self.execute() @@ -341,8 +343,9 @@ def from_id(self, tiger, queue, state, task_id, load_executions=0): serialized_executions = [] # XXX: No timestamp for now if serialized_data: - data = json.loads(serialized_data) - executions = [json.loads(e) for e in serialized_executions if e] + json_decoder = tiger.config.get('JSON_DECODER') + data = json.loads(serialized_data, cls=json_decoder) + executions = [json.loads(e, cls=json_decoder) for e in serialized_executions if e] return Task(tiger, queue=queue, _data=data, _state=state, _executions=executions) else: @@ -370,6 +373,8 @@ def tasks_from_queue(self, tiger, queue, state, skip=0, limit=1000, tasks = [] + json_decoder = tiger.config.get('JSON_DECODER') + if items: tss = [datetime.datetime.utcfromtimestamp(item[1]) for item in items] if load_executions: @@ -380,8 +385,8 @@ def tasks_from_queue(self, tiger, queue, state, skip=0, limit=1000, results = pipeline.execute() for serialized_data, serialized_executions, ts in zip(results[0], results[1:], tss): - data = json.loads(serialized_data) - executions = [json.loads(e) for e in serialized_executions if e] + data = json.loads(serialized_data, cls=json_decoder) + executions = [json.loads(e, cls=json_decoder) for e in serialized_executions if e] task = Task(tiger, queue=queue, _data=data, _state=state, _ts=ts, 
_executions=executions) @@ -390,7 +395,7 @@ def tasks_from_queue(self, tiger, queue, state, skip=0, limit=1000, else: data = tiger.connection.mget([tiger._key('task', item[0]) for item in items]) for serialized_data, ts in zip(data, tss): - data = json.loads(serialized_data) + data = json.loads(serialized_data, cls=json_decoder) task = Task(tiger, queue=queue, _data=data, _state=state, _ts=ts) tasks.append(task) diff --git a/tasktiger/worker.py b/tasktiger/worker.py index 4e4bd967..6d6424b5 100644 --- a/tasktiger/worker.py +++ b/tasktiger/worker.py @@ -44,6 +44,8 @@ def __init__(self, tiger, queues=None, exclude_queues=None, self._did_work = True self._last_task_check = 0 self.stats_thread = None + self.json_encoder = tiger.config.get('JSON_ENCODER') + self.json_decoder = tiger.config.get('JSON_DECODER') if queues: self.only_queues = set(queues) @@ -327,7 +329,7 @@ def _execute_forked(self, tasks, log): ''.join(traceback.format_exception(*exc_info)) execution['success'] = success execution['host'] = socket.gethostname() - serialized_execution = json.dumps(execution) + serialized_execution = json.dumps(execution, cls=self.json_encoder) for task in tasks: self.connection.rpush(self._key('task', task.id, 'executions'), serialized_execution) @@ -544,7 +546,7 @@ def _process_queue_tasks(self, queue, queue_lock, task_ids, now, log): tasks = [] for task_id, serialized_task in zip(task_ids, serialized_tasks): if serialized_task: - task_data = json.loads(serialized_task) + task_data = json.loads(serialized_task, cls=self.json_decoder) else: # In the rare case where we don't find the task which is # queued (see ReliabilityTestCase.test_task_disappears), @@ -671,12 +673,14 @@ def _execute_task_group(self, queue, tasks, all_task_ids, queue_lock): task.serialized_func, None, {key: kwargs.get(key) for key in task.lock_key}, + cls=self.json_encoder ) else: lock_id = gen_unique_id( task.serialized_func, task.args, task.kwargs, + cls=self.json_encoder ) if lock_id not in lock_ids: 
@@ -739,7 +743,7 @@ def _mark_done(): self._key('task', task.id, 'executions'), -1) if execution: - execution = json.loads(execution) + execution = json.loads(execution, cls=self.json_decoder) if execution and execution.get('retry'): if 'retry_method' in execution: