Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 20 additions & 4 deletions rdagent/app/data_science/loop.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,28 @@
import asyncio
from collections.abc import Coroutine
from pathlib import Path
from typing import Optional
from typing import Annotated, Optional

import fire
import typer
from typing_extensions import Annotated

from rdagent.app.data_science.conf import DS_RD_SETTING
from rdagent.core.utils import import_class
from rdagent.log import rdagent_logger as logger
from rdagent.scenarios.data_science.loop import DataScienceRDLoop


async def run_and_submit_sota(loop_task: Coroutine, competition: str) -> None:
"""Run the loop coroutine task, and submit the SOTA experiment submission file to kaggle at the end."""
from rdagent.scenarios.kaggle.submission import submit_current_sota

try:
# wait for the loop end
await loop_task
finally:
# we do not care about exception, just make sure we can submit
submit_current_sota(competition=competition)


def main(
path: Optional[str] = None,
checkout: Annotated[bool, typer.Option("--checkout/--no-checkout", "-c/-C")] = True,
Expand Down Expand Up @@ -73,7 +84,12 @@ def main(
if exp_gen_cls is not None:
kaggle_loop.exp_gen = import_class(exp_gen_cls)(kaggle_loop.exp_gen.scen)

asyncio.run(kaggle_loop.run(step_n=step_n, loop_n=loop_n, all_duration=timeout))
if DS_RD_SETTING.auto_submit:
asyncio.run(
run_and_submit_sota(kaggle_loop.run(step_n=step_n, loop_n=loop_n, all_duration=timeout), competition)
)
else:
asyncio.run(kaggle_loop.run(step_n=step_n, loop_n=loop_n, all_duration=timeout))


if __name__ == "__main__":
Expand Down
Loading