1414from .dask_dag import DaskVineDag
1515from .cvine import VINE_TEMP
1616
17+ import os
18+ import time
19+ import random
1720import contextlib
1821import cloudpickle
19- import os
2022from uuid import uuid4
23+ from collections import defaultdict
2124
2225try :
2326 import rich
@@ -123,6 +126,7 @@ def get(self, dsk, keys, *,
123126 lib_command = None ,
124127 lib_modules = None ,
125128 task_mode = 'tasks' ,
129+ scheduling_mode = 'FIFO' ,
126130 env_per_task = False ,
127131 progress_disable = False ,
128132 progress_label = "[green]tasks" ,
@@ -164,12 +168,16 @@ def get(self, dsk, keys, *,
164168 else :
165169 self .lib_modules = hoisting_modules if hoisting_modules else import_modules # Deprecated
166170 self .task_mode = task_mode
171+ self .scheduling_mode = scheduling_mode
167172 self .env_per_task = env_per_task
168173 self .progress_disable = progress_disable
169174 self .progress_label = progress_label
170175 self .wrapper = wrapper
171176 self .wrapper_proc = wrapper_proc
172177 self .prune_files = prune_files
178+ self .category_info = defaultdict (lambda : {"num_tasks" : 0 , "total_execution_time" : 0 })
179+ self .max_priority = float ('inf' )
180+ self .min_priority = float ('-inf' )
173181
174182 if submit_per_cycle is not None and submit_per_cycle < 1 :
175183 submit_per_cycle = None
@@ -274,6 +282,8 @@ def _dask_execute(self, dsk, keys):
274282 print (f"{ t .key } ran on { t .hostname } " )
275283
276284 if t .successful ():
285+ self .category_info [t .category ]["num_tasks" ] += 1
286+ self .category_info [t .category ]["total_execution_time" ] += t .resources_measured .wall_time
277287 result_file = DaskVineFile (t .output_file , t .key , dag , self .task_mode )
278288 rs = dag .set_result (t .key , result_file )
279289 self ._enqueue_dask_calls (dag , tag , rs , self .retries , enqueued_calls )
@@ -335,7 +345,42 @@ def _enqueue_dask_calls(self, dag, tag, rs, retries, enqueued_calls):
335345 if lazy and self .checkpoint_fn :
336346 lazy = self .checkpoint_fn (dag , k )
337347
348+ # each task has a category name
338349 cat = self .category_name (sexpr )
350+
351+ task_depth = dag .depth_of (k )
352+ if self .scheduling_mode == 'random' :
353+ priority = random .randint (self .min_priority , self .max_priority )
354+ elif self .scheduling_mode == 'depth-first' :
355+ # prioritize tasks with greater depth to learn more about the different kinds of tasks
356+ priority = task_depth
357+ elif self .scheduling_mode == 'breadth-first' :
358+ # prefer to start all branches as soon as possible
359+ priority = - task_depth
360+ elif self .scheduling_mode == 'longest-category-first' :
361+ # if no tasks have been executed in this category, set a high priority so that we know more information about each category
362+ if self .category_info [cat ]["num_tasks" ]:
363+ priority = self .category_info [cat ]["total_execution_time" ] / self .category_info [cat ]["num_tasks" ]
364+ else :
365+ priority = self .max_priority
366+ elif self .scheduling_mode == 'shortest-category-first' :
367+ # if no tasks have been executed in this category, set a high priority so that we know more information about each category
368+ if self .category_info [cat ]["num_tasks" ]:
369+ priority = - self .category_info [cat ]["total_execution_time" ] / self .category_info [cat ]["num_tasks" ]
370+ else :
371+ priority = self .max_priority
372+ elif self .scheduling_mode == 'FIFO' :
373+ # first in first out, the default behavior
374+ priority = - round (time .time (), 6 )
375+ elif self .scheduling_mode == 'LIFO' :
376+ # last in first out, the opposite of FIFO
377+ priority = round (time .time (), 6 )
378+ elif self .scheduling_mode == 'largest-input-first' :
379+ # best for saving disk space (with pruning)
380+ priority = sum ([len (dag .get_result (c )._file ) for c in dag .get_children (k )])
381+ else :
382+ raise ValueError (f"Unknown scheduling mode { self .scheduling_mode } " )
383+
339384 if self .task_mode == 'tasks' :
340385 if cat not in self ._categories_known :
341386 if self .resources :
@@ -357,6 +402,7 @@ def _enqueue_dask_calls(self, dag, tag, rs, retries, enqueued_calls):
357402 worker_transfers = lazy ,
358403 wrapper = self .wrapper )
359404
405+ t .set_priority (priority )
360406 if self .env_per_task :
361407 t .set_command (
362408 f"mkdir envdir && tar -xf { self ._environment_name } -C envdir && envdir/bin/run_in_env { t ._command } " )
@@ -374,6 +420,7 @@ def _enqueue_dask_calls(self, dag, tag, rs, retries, enqueued_calls):
374420 worker_transfers = lazy ,
375421 wrapper = self .wrapper )
376422
423+ t .set_priority (priority )
377424 t .set_tag (tag ) # tag that identifies this dag
378425
379426 enqueued_calls .append (t )
@@ -632,6 +679,7 @@ def __init__(self, m,
632679 self .set_category (category )
633680 if worker_transfers :
634681 self .enable_temp_output ()
682+
635683 if extra_files :
636684 for f , name in extra_files .items ():
637685 self .add_input (f , name )
0 commit comments