@@ -1,5 +1,6 @@
import fnmatch
import logging
import random
import typing
from dataclasses import dataclass, field
from typing import Literal, Optional
@@ -138,6 +139,28 @@ def subset_from_regexp(self, column, regexp):
            task_metadata=self.task_metadata,
        )

    def subset_from_task_ratio(self, ratio, seed):
        """Return a copy of the benchmark restricted to a random fraction `ratio`
        of its unique task names, sampled reproducibly from `seed`. All env_args
        belonging to a selected task are kept."""
        rng = random.Random(seed)  # private RNG, leaves the global random state untouched
        # sort before shuffling: set iteration order depends on the process hash
        # seed, which would otherwise make the seeded shuffle non-reproducible
        # across runs
        task_names = sorted({env_args.task_name for env_args in self.env_args_list})
        rng.shuffle(task_names)
        num_tasks = int(len(task_names) * ratio)  # floor of ratio * number of tasks
        task_name_subset = task_names[:num_tasks]

        return Benchmark(
            name=f"{self.name}[ratio={ratio}, seed={seed}]",
            high_level_action_set_args=self.high_level_action_set_args,
            is_multi_tab=self.is_multi_tab,
            supports_parallel_seeds=self.supports_parallel_seeds,
            backends=self.backends,
            env_args_list=[
                env_args
                for env_args in self.env_args_list
                if env_args.task_name in task_name_subset
            ],
            task_metadata=self.task_metadata,
        )

    def dependency_graph_over_tasks(self) -> dict[str, list[str]]:
        # recover all unique task_names present in the benchmark
        task_names = list(set([env_args.task_name for env_args in self.env_args_list]))
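A quick usage sketch of the new method (hypothetical; the import path is an assumption, and the 812-task count is specific to the WebArena benchmark used in the test below):

    # assumed import path; the test module pulls DEFAULT_BENCHMARKS from the same package
    from browsergym.experiments.benchmark import DEFAULT_BENCHMARKS

    benchmark = DEFAULT_BENCHMARKS["webarena"]()  # 812 unique tasks
    half = benchmark.subset_from_task_ratio(ratio=0.5, seed=1)
    assert len(half.env_args_list) == 812 // 2
    assert half.name == "webarena[ratio=0.5, seed=1]"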
tests/experiments/test_benchmark.py (38 additions, 0 deletions)
@@ -1,5 +1,6 @@
import dataclasses
import os
import random
import re
import tempfile

@@ -92,6 +93,43 @@ def test_benchmark_subset():
    assert dict_1 == dict_2


def test_benchmark_subset_from_task_ratio():
    benchmark: Benchmark = DEFAULT_BENCHMARKS["webarena"]()

    # Store the initial global random state
    initial_state = random.getstate()

    benchmark_subset = benchmark.subset_from_task_ratio(ratio=0.5, seed=1)
    assert len(benchmark_subset.env_args_list) == 812 // 2
    assert benchmark_subset.name == "webarena[ratio=0.5, seed=1]"

    # Verify the global random state hasn't changed
    assert random.getstate() == initial_state

    benchmark_subset_1 = benchmark_subset.subset_from_task_ratio(ratio=0.5, seed=1)
    benchmark_subset_2 = benchmark_subset.subset_from_task_ratio(ratio=0.5, seed=2)

    # Verify the global random state still hasn't changed
    assert random.getstate() == initial_state

    # Different seeds should select different task subsets
    assert not np.all(
        [
            env_args.task_name == env_args_2.task_name
            for env_args, env_args_2 in zip(
                benchmark_subset_1.env_args_list, benchmark_subset_2.env_args_list
            )
        ]
    )

    dict_1 = benchmark_subset_1.to_dict()
    dict_1.pop("name")
    dict_2 = benchmark_subset_2.to_dict()
    dict_2.pop("name")
    assert len(dict_1["env_args_list"]) == len(dict_2["env_args_list"])
    assert dict_1 != dict_2
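
One property the test above doesn't pin down is same-seed reproducibility; a sketch of such a check (hypothetical, not part of this PR) could look like:

    # the same ratio/seed pair should reproduce the exact same subset
    subset_a = benchmark.subset_from_task_ratio(ratio=0.5, seed=1)
    subset_b = benchmark.subset_from_task_ratio(ratio=0.5, seed=1)
    assert [e.task_name for e in subset_a.env_args_list] == [
        e.task_name for e in subset_b.env_args_list
    ]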


def test_prepare_backend_miniwob():
    MINIWOB_URL = os.environ["MINIWOB_URL"]
    try: