File tree Expand file tree Collapse file tree 1 file changed +5
-2
lines changed
Expand file tree Collapse file tree 1 file changed +5
-2
lines changed Original file line number Diff line number Diff line change @@ -173,6 +173,8 @@ def _get_split_indices(
173173 valid_idx = indices [train_size : train_size + valid_size ]
174174
175175 if save_indices :
176+ if indices_dir != '.' :
177+ os .makedirs (indices_dir , exist_ok = True )
176178 np .savetxt (f"{ indices_dir } /train_indices.txt" , train_idx .numpy (), fmt = "%d" )
177179 np .savetxt (f"{ indices_dir } /valid_indices.txt" , valid_idx .numpy (), fmt = "%d" )
178180
@@ -353,7 +355,8 @@ def get_dataloader(
353355 validation_split = 0.1 ,
354356 pin_memory = True ,
355357 manual_seed = None ,
356- save_indices = False
358+ save_indices = False ,
359+ indices_dir = '.'
357360 ):
358361 """
359362 :param int batch_size: size of the training batches
@@ -371,7 +374,7 @@ def get_dataloader(
371374 train_size = None ,
372375 manual_seed = manual_seed ,
373376 save_indices = save_indices ,
374- indices_dir = '.'
377+ indices_dir = indices_dir
375378 )
376379
377380 tensor_dataset = torch .utils .data .TensorDataset (dataset )
You can’t perform that action at this time.
0 commit comments