diff --git a/scripts/test.py b/scripts/test.py index 116164b9..4fca2c89 100755 --- a/scripts/test.py +++ b/scripts/test.py @@ -42,8 +42,14 @@ cfg.data_dir = os.path.join(training_dir, 'data') if cfg.results_dir is None: cfg.results_dir = os.path.join(training_dir, 'results') + if cfg.baseline_dir is None: - cfg.baseline_dir = os.path.join(training_dir, 'baseline_' + BASELINE_VERSION) + baseline_img_repo = os.environ.get('OIDN_VALIDATION_REPO') + if baseline_img_repo is not None: + run(f'git clone --branch {BASELINE_VERSION} --single-branch --depth 1 "{baseline_img_repo}" baseline_images_repo') + cfg.baseline_dir = os.path.join('baseline_images_repo', 'images') + else: + cfg.baseline_dir = os.path.join(training_dir, 'baseline_' + BASELINE_VERSION) if cfg.command == 'run': # Detect the OIDN binary directory