Skip to content

Commit 3c37c03

Browse files
Create directory if needed (#43)
1 parent d5c6c12 commit 3c37c03

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

src/molearn/data/pdb_data.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,8 @@ def _get_split_indices(
173173
valid_idx = indices[train_size : train_size + valid_size]
174174

175175
if save_indices:
176+
if indices_dir != '.':
177+
os.makedirs(indices_dir, exist_ok=True)
176178
np.savetxt(f"{indices_dir}/train_indices.txt", train_idx.numpy(), fmt="%d")
177179
np.savetxt(f"{indices_dir}/valid_indices.txt", valid_idx.numpy(), fmt="%d")
178180

@@ -353,7 +355,8 @@ def get_dataloader(
353355
validation_split=0.1,
354356
pin_memory=True,
355357
manual_seed=None,
356-
save_indices=False
358+
save_indices=False,
359+
indices_dir='.'
357360
):
358361
"""
359362
:param int batch_size: size of the training batches
@@ -371,7 +374,7 @@ def get_dataloader(
371374
train_size=None,
372375
manual_seed=manual_seed,
373376
save_indices=save_indices,
374-
indices_dir='.'
377+
indices_dir=indices_dir
375378
)
376379

377380
tensor_dataset = torch.utils.data.TensorDataset(dataset)

0 commit comments

Comments
 (0)