diff --git a/changelog.md b/changelog.md index ace0426e..b8695b57 100644 --- a/changelog.md +++ b/changelog.md @@ -1,3 +1,11 @@ +Upcoming (TBD) +============== + +Features +--------- +* Add a `--progress` progress-bar option with `--batch`. + + 1.66.0 (2026/03/21) ============== diff --git a/mycli/main.py b/mycli/main.py index d5f2b403..bc2573af 100755 --- a/mycli/main.py +++ b/mycli/main.py @@ -33,6 +33,7 @@ import click from configobj import ConfigObj import keyring +import prompt_toolkit from prompt_toolkit import print_formatted_text from prompt_toolkit.application.current import get_app from prompt_toolkit.auto_suggest import AutoSuggestFromHistory, ThreadedAutoSuggest @@ -53,7 +54,8 @@ from prompt_toolkit.layout.processors import ConditionalProcessor, HighlightMatchingBracketProcessor from prompt_toolkit.lexers import PygmentsLexer from prompt_toolkit.output import ColorDepth -from prompt_toolkit.shortcuts import CompleteStyle, PromptSession +from prompt_toolkit.shortcuts import CompleteStyle, ProgressBar, PromptSession +from prompt_toolkit.shortcuts.progress_bar import formatters as progress_bar_formatters import pymysql from pymysql.constants.CR import CR_SERVER_LOST from pymysql.constants.ER import ACCESS_DENIED_ERROR, HANDSHAKE_ERROR @@ -2000,6 +2002,7 @@ def get_last_query(self) -> str | None: ) @click.argument("database", default=None, nargs=1) @click.option('--batch', 'batch_file', type=str, help='SQL script to execute in batch mode.') +@click.option('--progress', 'batch_progress_bar', is_flag=True, help='Show progress with --batch.') @click.option("--noninteractive", is_flag=True, help="Don't prompt during batch input. Recommended.") @click.option( '--format', 'batch_format', type=click.Choice(['default', 'csv', 'tsv', 'table']), help='Format for batch or --execute output.' @@ -2070,6 +2073,7 @@ def click_entrypoint( password_file: str | None, noninteractive: bool, batch_file: str | None, + batch_progress_bar: str | None, batch_format: str | None, throttle: float, use_keyring_cli_opt: str | None, @@ -2572,17 +2576,70 @@ def dispatch_batch_statements(statements: str, batch_counter: int) -> None: click.secho(str(e), err=True, fg="red") sys.exit(1) - if batch_file or not sys.stdin.isatty(): - if batch_file: - if not sys.stdin.isatty() and batch_file != '-': - click.secho('Ignoring STDIN since --batch was also given.', err=True, fg='red') - try: - batch_h = click.open_file(batch_file) - except (OSError, FileNotFoundError): - click.secho(f'Failed to open --batch file: {batch_file}', err=True, fg='red') - sys.exit(1) - else: - batch_h = click.get_text_stream('stdin') + if batch_file and batch_file != '-' and batch_progress_bar and sys.stderr.isatty(): + # The actual number of SQL statements can be greater, if there is more than + # one statement per line, but this is how the progress bar will count. + goal_statements = 0 + if not sys.stdin.isatty() and batch_file != '-': + click.secho('Ignoring STDIN since --batch was also given.', err=True, fg='yellow') + if os.path.exists(batch_file) and not os.path.isfile(batch_file): + click.secho('--progress is only compatible with a plain file.', err=True, fg='red') + sys.exit(1) + try: + batch_count_h = click.open_file(batch_file) + for _statement, _counter in statements_from_filehandle(batch_count_h): + goal_statements += 1 + batch_count_h.close() + batch_h = click.open_file(batch_file) + except (OSError, FileNotFoundError): + click.secho(f'Failed to open --batch file: {batch_file}', err=True, fg='red') + sys.exit(1) + except ValueError as e: + click.secho(f'Error reading --batch file: {batch_file}: {e}', err=True, fg='red') + sys.exit(1) + try: + if goal_statements: + pb_style = prompt_toolkit.styles.Style.from_dict({'bar-a': 'reverse'}) + custom_formatters = [ + progress_bar_formatters.Bar(start='[', end=']', sym_a=' ', sym_b=' ', sym_c=' '), + progress_bar_formatters.Text(' '), + progress_bar_formatters.Progress(), + progress_bar_formatters.Text(' '), + progress_bar_formatters.Text('eta ', style='class:time-left'), + progress_bar_formatters.TimeLeft(), + progress_bar_formatters.Text(' ', style='class:time-left'), + ] + err_output = prompt_toolkit.output.create_output(stdout=sys.stderr, always_prefer_tty=True) + with ProgressBar(style=pb_style, formatters=custom_formatters, output=err_output) as pb: + for pb_counter in pb(range(goal_statements)): + statement, _untrusted_counter = next(statements_from_filehandle(batch_h)) + dispatch_batch_statements(statement, pb_counter) + except (ValueError, StopIteration) as e: + click.secho(str(e), err=True, fg='red') + sys.exit(1) + finally: + batch_h.close() + sys.exit(0) + + if batch_file: + if not sys.stdin.isatty() and batch_file != '-': + click.secho('Ignoring STDIN since --batch was also given.', err=True, fg='red') + try: + batch_h = click.open_file(batch_file) + except (OSError, FileNotFoundError): + click.secho(f'Failed to open --batch file: {batch_file}', err=True, fg='red') + sys.exit(1) + try: + for statement, counter in statements_from_filehandle(batch_h): + dispatch_batch_statements(statement, counter) + batch_h.close() + except ValueError as e: + click.secho(str(e), err=True, fg='red') + sys.exit(1) + sys.exit(0) + + if not sys.stdin.isatty(): + batch_h = click.get_text_stream('stdin') try: for statement, counter in statements_from_filehandle(batch_h): dispatch_batch_statements(statement, counter) diff --git a/test/test_main.py b/test/test_main.py index f47e5beb..e800ec85 100644 --- a/test/test_main.py +++ b/test/test_main.py @@ -6,8 +6,10 @@ import io import os import shutil +import sys from tempfile import NamedTemporaryFile from textwrap import dedent +from types import SimpleNamespace import click from click.testing import CliRunner @@ -1465,6 +1467,79 @@ def test_batch_file(monkeypatch): os.remove(batch_file.name) +def test_batch_file_with_progress(monkeypatch): + mycli_main, MockMyCli = _noninteractive_mock_mycli(monkeypatch) + runner = CliRunner() + + class DummyProgressBar: + calls = [] + + def __init__(self, *args, **kwargs): + pass + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def __call__(self, iterable): + values = list(iterable) + DummyProgressBar.calls.append(values) + return values + + monkeypatch.setattr(mycli_main, 'ProgressBar', DummyProgressBar) + monkeypatch.setattr(mycli_main.prompt_toolkit.output, 'create_output', lambda **kwargs: object()) + monkeypatch.setattr( + mycli_main, + 'sys', + SimpleNamespace( + stdin=SimpleNamespace(isatty=lambda: False), + stderr=SimpleNamespace(isatty=lambda: True), + exit=sys.exit, + ), + ) + + with NamedTemporaryFile(prefix=TEMPFILE_PREFIX, mode='w', delete=False) as batch_file: + batch_file.write('select 2;\nselect 2;\nselect 2;\n') + batch_file.flush() + + try: + result = runner.invoke( + mycli_main.click_entrypoint, + args=['--batch', batch_file.name, '--progress'], + ) + assert result.exit_code == 0 + assert MockMyCli.ran_queries == ['select 2;\n', 'select 2;\n', 'select 2;\n'] + assert DummyProgressBar.calls == [[0, 1, 2]] + finally: + os.remove(batch_file.name) + + +def test_batch_file_with_progress_requires_plain_file(monkeypatch, tmp_path): + mycli_main, MockMyCli = _noninteractive_mock_mycli(monkeypatch) + runner = CliRunner() + + monkeypatch.setattr( + mycli_main, + 'sys', + SimpleNamespace( + stdin=SimpleNamespace(isatty=lambda: False), + stderr=SimpleNamespace(isatty=lambda: True), + exit=sys.exit, + ), + ) + + result = runner.invoke( + mycli_main.click_entrypoint, + args=['--batch', str(tmp_path), '--progress'], + ) + + assert result.exit_code != 0 + assert '--progress is only compatible with a plain file.' in result.output + assert MockMyCli.ran_queries == [] + + def test_execute_arg_warns_about_ignoring_stdin(monkeypatch): mycli_main, MockMyCli = _noninteractive_mock_mycli(monkeypatch) runner = CliRunner()