8 changes: 8 additions & 0 deletions AGENTS.md
@@ -93,3 +93,11 @@ python butler.py format

This will format the changed code in your current branch.
It's possible to get into a state where linting and formatting contradict each other. In this case, STOP; the human will fix it.

## Codebase Notes

### Batch Logic

- `src/clusterfuzz/_internal/batch/gcp.py` contains low-level GCP Batch client logic. `check_congestion_jobs` is placed here as it directly queries job status using the client.
- `src/clusterfuzz/_internal/batch/service.py` contains high-level batch service logic, including configuration management. `create_congestion_job` is placed here because it depends on configuration logic (`_get_specs_from_config`, etc.).
- `src/clusterfuzz/_internal/google_cloud_utils/batch.py` acts as a facade/wrapper for backward compatibility and convenience, delegating to `gcp.py` and `service.py` (see the sketch below).
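
A minimal sketch of this layering (simplified and illustrative only; the top-level imports are an assumption, and the actual delegation is in the diff below):

```python
# google_cloud_utils/batch.py -- the facade: no logic of its own, it only
# delegates to the two layers described above.
from clusterfuzz._internal.batch import gcp
from clusterfuzz._internal.batch.service import BatchService


def check_congestion_jobs(job_ids):
  return gcp.check_congestion_jobs(job_ids)  # Low-level: queries the client.


def create_congestion_job(job_type, gce_region=None):
  # High-level: BatchService resolves specs/subconfigs from config first.
  return BatchService().create_congestion_job(job_type, gce_region)
```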
35 changes: 32 additions & 3 deletions src/clusterfuzz/_internal/batch/gcp.py
@@ -61,11 +61,13 @@ def get_job_name():
  return 'j-' + str(uuid.uuid4()).lower()


-def _get_task_spec(batch_workload_spec):
+def _get_task_spec(batch_workload_spec, commands=None):
  """Gets the task spec based on the batch workload spec."""
  runnable = batch.Runnable()
  runnable.container = batch.Runnable.Container()
  runnable.container.image_uri = batch_workload_spec.docker_image
+  if commands:
+    runnable.container.commands = commands
  clusterfuzz_release = batch_workload_spec.clusterfuzz_release
  runnable.container.options = (
      '--memory-swappiness=40 --shm-size=1.9g --rm --net=host '
@@ -152,6 +154,29 @@ def count_queued_or_scheduled_tasks(project: str,
  return (queued, scheduled)


def check_congestion_jobs(job_ids: List[str]) -> int:
  """Returns how many of |job_ids| completed (i.e. were not congested)."""
  completed_count = 0
  for job_id in job_ids:
    try:
      job = _batch_client().get_job(name=job_id)
      # SUCCEEDED, RUNNING, and FAILED all count as completed (i.e. not
      # congested): a job in any of these states was successfully scheduled
      # and started running. A job that is still QUEUED is waiting to be
      # scheduled, which implies congestion.
      if job.status.state in (batch.JobStatus.State.SUCCEEDED,
                              batch.JobStatus.State.RUNNING,
                              batch.JobStatus.State.FAILED):
        completed_count += 1
    except Exception:
      # The job may have been deleted, or the request failed. Either way,
      # don't count it as completed.
      logs.warning(f'Failed to get job {job_id}.')

  return completed_count


class GcpBatchClient(RemoteTaskInterface):
  """A client for creating and managing jobs on the GCP Batch service.

@@ -161,7 +186,7 @@ class GcpBatchClient(RemoteTaskInterface):
  specification, which defines the container image and command to run.
  """

-  def create_job(self, spec, input_urls: List[str]):
+  def create_job(self, spec, input_urls: List[str], commands=None):
    """Creates and starts a batch job from |spec| that executes all tasks.

    This method creates a new GCP Batch job with a single task group. The
@@ -177,7 +202,7 @@ def create_job(self, spec, input_urls: List[str]):
        for input_url in input_urls
    ]
    task_group.task_environments = task_environments
-    task_group.task_spec = _get_task_spec(spec)
+    task_group.task_spec = _get_task_spec(spec, commands=commands)
    task_group.task_count_per_node = TASK_COUNT_PER_NODE
    assert task_group.task_count_per_node == 1, 'This is a security issue'
@@ -198,3 +223,7 @@ def create_job(self, spec, input_urls: List[str]):
    job_result = _send_create_job_request(create_request)
    logs.info(f'Created batch job id={job_name}.', spec=spec)
    return job_result

  def get_job(self, name):
    """Gets a batch job."""
    return _batch_client().get_job(name=name)
30 changes: 30 additions & 0 deletions src/clusterfuzz/_internal/batch/service.py
@@ -21,6 +21,8 @@
from typing import Dict
from typing import List

from google.cloud import batch_v1 as batch

from clusterfuzz._internal.base import tasks
from clusterfuzz._internal.base import utils
from clusterfuzz._internal.base.tasks import task_utils
@@ -114,6 +116,16 @@ def _get_subconfig(batch_config, instance_spec):
  return all_subconfigs[weighted_subconfig.name]


def _get_subconfig_for_region(batch_config, instance_spec, region):
  """Returns the subconfig for |region|, or raises if none exists."""
  all_subconfigs = batch_config.get('subconfigs', {})
  instance_subconfigs = instance_spec['subconfigs']
  for subconfig in instance_subconfigs:
    full_subconfig = all_subconfigs[subconfig['name']]
    if full_subconfig['region'] == region:
      return full_subconfig
  raise ValueError(f'No subconfig for region {region}')


def _get_specs_from_config(batch_tasks) -> Dict:
  """Gets the configured specifications for a batch workload."""
  if not batch_tasks:
@@ -227,3 +239,21 @@ def create_uworker_main_batch_jobs(self, batch_tasks: List[BatchTask]):
      jobs.append(self._client.create_job(spec, input_urls_portion))

    return jobs

  def create_congestion_job(self, job_type, gce_region=None):
    """Creates a trivial congestion job used to probe Batch scheduling."""
    batch_tasks = [BatchTask('fuzz', job_type, 'CONGESTION')]
    specs = _get_specs_from_config(batch_tasks)
    spec = specs[('fuzz', job_type)]
    if gce_region:
      batch_config = _get_batch_config()
      config_map = _get_config_names(batch_tasks)
      config_name, _, _ = config_map[('fuzz', job_type)]
      instance_spec = batch_config.get('mapping').get(config_name)
      subconfig = _get_subconfig_for_region(batch_config, instance_spec,
                                            gce_region)
      spec.gce_region = subconfig['region']
      spec.network = subconfig['network']
      spec.subnetwork = subconfig['subnetwork']

    # The command just needs to finish quickly; the job type is required but
    # otherwise unused.
    return self._client.create_job(
        spec, ['CONGESTION'], commands=['echo', 'hello'])
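
# A hypothetical usage sketch (names are illustrative), mirroring how the
# scheduler cron below consumes this method:
#   service = BatchService()
#   result = service.create_congestion_job('libfuzzer_chrome_asan',
#                                          gce_region='us-central1')
#   data_types.CongestionJob(job_id=result.name, region='us-central1').put()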
102 changes: 87 additions & 15 deletions src/clusterfuzz/_internal/cron/schedule_fuzz.py
@@ -14,6 +14,7 @@
"""Cron job to schedule fuzz tasks that run on batch."""

import collections
import datetime
import multiprocessing
import random
import time
@@ -25,10 +26,10 @@

from clusterfuzz._internal.base import tasks
from clusterfuzz._internal.base import utils
from clusterfuzz._internal.config import local_config
from clusterfuzz._internal.datastore import data_types
from clusterfuzz._internal.datastore import ndb_utils
# google_cloud_utils.batch is the facade over batch/gcp.py and
# batch/service.py; it provides both check_congestion_jobs and
# create_congestion_job, which this module uses below.
from clusterfuzz._internal.google_cloud_utils import batch
from clusterfuzz._internal.google_cloud_utils import credentials
from clusterfuzz._internal.metrics import logs

@@ -295,7 +296,13 @@ def get_fuzz_tasks(self) -> List[tasks.Task]:
    return fuzz_tasks


-def get_fuzz_tasks(available_cpus: int) -> [tasks.Task]:
+def get_fuzz_tasks(project, regions) -> List[tasks.Task]:
  """Returns the fuzz tasks to schedule on |project| across |regions|."""
  available_cpus = get_available_cpus(project, regions)
  logs.info(f'{available_cpus} available CPUs.')

  if not available_cpus:
    return []

  if utils.is_oss_fuzz():
    scheduler = OssfuzzFuzzTaskScheduler(available_cpus)
  else:
@@ -377,26 +384,91 @@ def respect_project_max_cpus(num_cpus):
  return min(max_cpus_per_schedule, num_cpus)


def get_congested_regions() -> List[str]:
  """Returns a list of congested regions. The strategy is as follows:
  congestion jobs are launched in every region each time this cron runs.
  Assuming the cron runs more than 3 times an hour, if there aren't 3
  completed jobs in the last hour, they either failed (unlikely, since they
  are trivial) or never ran because of congestion.
  """
  one_hour_ago = datetime.datetime.utcnow() - datetime.timedelta(hours=1)
  congestion_jobs = list(
      data_types.CongestionJob.query(
          data_types.CongestionJob.timestamp > one_hour_ago))

  jobs_by_region = collections.defaultdict(list)
  for job in congestion_jobs:
    if job.region:
      jobs_by_region[job.region].append(job)

  congested_regions = []
  for region, jobs in jobs_by_region.items():
    # Sort by timestamp, newest first.
    jobs.sort(key=lambda j: j.timestamp, reverse=True)
    # Check the 3 most recent jobs.
    recent_jobs = jobs[:3]
    if len(recent_jobs) < 3:
      continue

    completed_count = batch.check_congestion_jobs(
        [job.job_id for job in recent_jobs])
    if completed_count < 3:
      # TODO(metzman): Add some monitoring here.
      logs.error(f'Congestion detected in {region}: {completed_count}/3 '
                 'congestion jobs completed in the last hour.')
      congested_regions.append(region)
  return congested_regions
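
# Illustrative arithmetic for the heuristic above (assumed numbers, not set
# anywhere in this change): with a 10-minute cron cadence, each region
# receives 6 probe jobs per hour, so requiring 3 completed probes tolerates
# a few slow or failed runs before a region is declared congested.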


def schedule_congestion_jobs(fuzz_tasks, all_regions):
  """Schedules congestion jobs for all regions."""
  # Run a hello-world task that finishes very quickly. The job field is
  # ignored, but we need one, so pick an arbitrary one.
  clusterfuzz_job_type = None
  if fuzz_tasks:
    clusterfuzz_job_type = fuzz_tasks[0].job
  else:
    # If no tasks were scheduled, try to get a job type from the DB to run
    # the congestion job.
    job = data_types.Job.query().get()
    if job:
      clusterfuzz_job_type = job.name

  if clusterfuzz_job_type:
    for region in all_regions:
      try:
        batch_job_result = batch.create_congestion_job(
            clusterfuzz_job_type, gce_region=region)
        data_types.CongestionJob(
            job_id=batch_job_result.name, region=region).put()
      except Exception:
        logs.error(f'Failed to create congestion job in {region}.')


def schedule_fuzz_tasks() -> bool:
  """Schedules fuzz tasks."""
-  multiprocessing.set_start_method('spawn')
+  try:
+    multiprocessing.set_start_method('spawn')
+  except RuntimeError:  # Ignore if this was done previously.
+    pass

  batch_config = local_config.BatchConfig()
  project = batch_config.get('project')
-  regions = get_batch_regions(batch_config)
+  all_regions = get_batch_regions(batch_config)
+  congested_regions = get_congested_regions()
+  regions = [r for r in all_regions if r not in congested_regions]

  start = time.time()
-  available_cpus = get_available_cpus(project, regions)
-  logs.info(f'{available_cpus} available CPUs.')
-  if not available_cpus:
-    return False
-  fuzz_tasks = get_fuzz_tasks(available_cpus)
-  if not fuzz_tasks:
-    logs.error('No fuzz tasks found to schedule.')
-    return False
-  logs.info(f'Adding {fuzz_tasks} to preprocess queue.')
-  tasks.bulk_add_tasks(fuzz_tasks, queue=tasks.PREPROCESS_QUEUE, eta_now=True)
-  logs.info(f'Scheduled {len(fuzz_tasks)} fuzz tasks.')
+  fuzz_tasks = get_fuzz_tasks(project, regions)
+
+  if fuzz_tasks:
+    logs.info(f'Adding {fuzz_tasks} to preprocess queue.')
+    tasks.bulk_add_tasks(
+        fuzz_tasks, queue=tasks.PREPROCESS_QUEUE, eta_now=True)
+    logs.info(f'Scheduled {len(fuzz_tasks)} fuzz tasks.')
+  else:
+    logs.info('No fuzz tasks scheduled.')
+
+  schedule_congestion_jobs(fuzz_tasks, all_regions)

  end = time.time()
  total = end - start
18 changes: 18 additions & 0 deletions src/clusterfuzz/_internal/datastore/data_types.py
@@ -1803,3 +1803,21 @@ class FuzzerTaskEvent(Model):
  def _pre_put_hook(self):
    self.ttl_expiry_timestamp = (
        datetime.datetime.now() + self.FUZZER_EVENT_TTL)


class CongestionJob(Model):
  """Congestion job. Used to measure congestion in GCP Batch."""
  CONGESTION_JOB_TTL = datetime.timedelta(days=3)

  # The job name (ID) in Batch.
  job_id = ndb.StringProperty()
  # The region the job is running in.
  region = ndb.StringProperty()
  # Time of creation.
  timestamp = ndb.DateTimeProperty(auto_now_add=True)
  # Expiration time for this entity.
  ttl_expiry_timestamp = ndb.DateTimeProperty()

  def _pre_put_hook(self):
    self.ttl_expiry_timestamp = (
        datetime.datetime.now() + self.CONGESTION_JOB_TTL)
25 changes: 25 additions & 0 deletions src/clusterfuzz/_internal/google_cloud_utils/batch.py
@@ -29,3 +29,28 @@ def create_uworker_main_batch_jobs(batch_tasks: List[BatchTask]):
"""Creates batch jobs."""
service = BatchService()
return service.create_uworker_main_batch_jobs(batch_tasks)


def create_congestion_job(job_type, gce_region=None):
  """Creates a congestion job."""
  service = BatchService()
  return service.create_congestion_job(job_type, gce_region)


def check_congestion_jobs(job_ids):
  """Checks the status of the congestion jobs."""
  from clusterfuzz._internal.batch import gcp
  return gcp.check_congestion_jobs(job_ids)


def count_queued_or_scheduled_tasks(project: str, region: str):
  """Counts queued or scheduled tasks in |project| and |region|."""
  # Delegates to the low-level implementation, which owns _batch_client().
  from clusterfuzz._internal.batch import gcp
  return gcp.count_queued_or_scheduled_tasks(project, region)
@@ -286,3 +286,73 @@ def test_config_limit(self):
    }]
    self.assertEqual(
        schedule_fuzz.get_cpu_usage(self.creds, 'region', 'project'), (2, 0))


@test_utils.with_cloud_emulators('datastore')
class ScheduleFuzzTasksTest(unittest.TestCase):
  """Tests for schedule_fuzz_tasks."""

  def setUp(self):
    test_helpers.patch(self, [
        'clusterfuzz._internal.cron.schedule_fuzz.get_available_cpus',
        'clusterfuzz._internal.cron.schedule_fuzz.get_fuzz_tasks',
        'clusterfuzz._internal.cron.schedule_fuzz.get_batch_regions',
        'clusterfuzz._internal.google_cloud_utils.batch.check_congestion_jobs',
        'clusterfuzz._internal.google_cloud_utils.batch.create_congestion_job',
        'clusterfuzz._internal.base.tasks.bulk_add_tasks',
    ])
    self.mock.get_batch_regions.return_value = ['us-central1']
    mock_job = unittest.mock.Mock()
    mock_job.name = 'congestion-job'
    self.mock.create_congestion_job.return_value = mock_job

  def test_is_congested_true(self):
    """Tests that scheduling stops when congested."""
    # Create 3 congestion jobs.
    for i in range(3):
      data_types.CongestionJob(job_id=f'job-{i}', region='us-central1').put()

    # Mock check_congestion_jobs to return 0 completed.
    self.mock.check_congestion_jobs.return_value = 0
    # Mock get_fuzz_tasks to return an empty list (simulating no CPUs or
    # other issues).
    self.mock.get_fuzz_tasks.return_value = []

    self.assertTrue(schedule_fuzz.schedule_fuzz_tasks())

    self.mock.get_fuzz_tasks.assert_called()
    # Verify that get_fuzz_tasks was called with an empty regions list,
    # since the only region is congested. Args are (project, regions).
    call_args = self.mock.get_fuzz_tasks.call_args[0]
    self.assertEqual(call_args[1], [])

  def test_is_congested_false(self):
    """Tests that scheduling proceeds when not congested."""
    # Create 3 congestion jobs.
    for i in range(3):
      data_types.CongestionJob(job_id=f'job-{i}', region='us-central1').put()

    # Mock check_congestion_jobs to return 3 completed.
    self.mock.check_congestion_jobs.return_value = 3
    mock_task = unittest.mock.Mock()
    mock_task.job = 'job1'
    self.mock.get_fuzz_tasks.return_value = [mock_task]

    self.assertTrue(schedule_fuzz.schedule_fuzz_tasks())
    self.mock.get_fuzz_tasks.assert_called()

  def test_no_congestion_job_if_no_tasks(self):
    """Tests that no congestion job is scheduled if there are no fuzz tasks."""
    self.mock.get_fuzz_tasks.return_value = []

    self.assertTrue(schedule_fuzz.schedule_fuzz_tasks())
    self.mock.create_congestion_job.assert_not_called()

  def test_congestion_job_scheduled(self):
    """Tests that a congestion job is scheduled when fuzz tasks are."""
    mock_task = unittest.mock.Mock()
    mock_task.job = 'job1'
    self.mock.get_fuzz_tasks.return_value = [mock_task]

    self.assertTrue(schedule_fuzz.schedule_fuzz_tasks())
    self.mock.create_congestion_job.assert_called_with(
        'job1', gce_region='us-central1')