diff --git a/AGENTS.md b/AGENTS.md
index a53e6d72b8..15460ef884 100644
--- a/AGENTS.md
+++ b/AGENTS.md
@@ -93,3 +93,11 @@ python butler.py format
 This will format the changed code in your current branch.
 It's possible to get into a state where linting and formatting contradict
 each other. In this case, STOP, the human will fix it.
+
+## Codebase Notes
+
+### Batch Logic
+
+- `src/clusterfuzz/_internal/batch/gcp.py` contains low-level GCP Batch client logic. `check_congestion_jobs` is placed here as it directly queries job status using the client.
+- `src/clusterfuzz/_internal/batch/service.py` contains high-level batch service logic, including configuration management. `create_congestion_job` is placed here because it depends on configuration logic (`_get_specs_from_config`, etc.).
+- `src/clusterfuzz/_internal/google_cloud_utils/batch.py` acts as a facade/wrapper for backward compatibility or convenience, delegating to `gcp.py` and `service.py`.
\ No newline at end of file
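Note: per the layering above, call sites import only the facade. A minimal
usage sketch (the job type and region are illustrative, and assume a 'fuzz'
mapping for that job type exists in the batch config):

    from clusterfuzz._internal.google_cloud_utils import batch

    # service.py resolves the config-driven workload spec; gcp.py issues
    # the actual GCP Batch API calls.
    job = batch.create_congestion_job(
        'libfuzzer_asan_job', gce_region='us-central1')
    completed = batch.check_congestion_jobs([job.name])
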
diff --git a/src/clusterfuzz/_internal/batch/gcp.py b/src/clusterfuzz/_internal/batch/gcp.py
index 9b155cce7b..583653e743 100644
--- a/src/clusterfuzz/_internal/batch/gcp.py
+++ b/src/clusterfuzz/_internal/batch/gcp.py
@@ -61,11 +61,13 @@ def get_job_name():
   return 'j-' + str(uuid.uuid4()).lower()
 
 
-def _get_task_spec(batch_workload_spec):
+def _get_task_spec(batch_workload_spec, commands=None):
   """Gets the task spec based on the batch workload spec."""
   runnable = batch.Runnable()
   runnable.container = batch.Runnable.Container()
   runnable.container.image_uri = batch_workload_spec.docker_image
+  if commands:
+    runnable.container.commands = commands
   clusterfuzz_release = batch_workload_spec.clusterfuzz_release
   runnable.container.options = (
       '--memory-swappiness=40 --shm-size=1.9g --rm --net=host '
@@ -152,6 +154,29 @@ def count_queued_or_scheduled_tasks(project: str,
   return (queued, scheduled)
 
 
+def check_congestion_jobs(job_ids: List[str]) -> int:
+  """Returns the number of jobs in |job_ids| that were scheduled."""
+  completed_count = 0
+  for job_id in job_ids:
+    try:
+      job = _batch_client().get_job(name=job_id)
+      # We count SUCCEEDED, RUNNING, and FAILED as completed (i.e. not
+      # congested): in any of these states the job was successfully
+      # scheduled and started running. If it is QUEUED, it is still waiting
+      # to be scheduled, which implies congestion.
+      if job.status.state in (batch.JobStatus.State.SUCCEEDED,
+                              batch.JobStatus.State.RUNNING,
+                              batch.JobStatus.State.FAILED):
+        completed_count += 1
+    except Exception:
+      # If we can't get the job, it may have been deleted or an error
+      # occurred.
+      # Either way, we don't count it as completed.
+      logs.warning(f'Failed to get job {job_id}.')
+
+  return completed_count
+
+
 class GcpBatchClient(RemoteTaskInterface):
   """A client for creating and managing jobs on the GCP Batch service.
@@ -161,7 +186,7 @@ class GcpBatchClient(RemoteTaskInterface):
   specification, which defines the container image and command to run.
   """
 
-  def create_job(self, spec, input_urls: List[str]):
+  def create_job(self, spec, input_urls: List[str], commands=None):
     """Creates and starts a batch job from |spec| that executes all tasks.
 
     This method creates a new GCP Batch job with a single task group. The
@@ -177,7 +202,7 @@ def create_job(self, spec, input_urls: List[str]):
         for input_url in input_urls
     ]
     task_group.task_environments = task_environments
-    task_group.task_spec = _get_task_spec(spec)
+    task_group.task_spec = _get_task_spec(spec, commands=commands)
     task_group.task_count_per_node = TASK_COUNT_PER_NODE
     assert task_group.task_count_per_node == 1, 'This is a security issue'
@@ -198,3 +223,7 @@ def create_job(self, spec, input_urls: List[str]):
     job_result = _send_create_job_request(create_request)
     logs.info(f'Created batch job id={job_name}.', spec=spec)
     return job_result
+
+  def get_job(self, name):
+    """Gets a batch job."""
+    return _batch_client().get_job(name=name)
diff --git a/src/clusterfuzz/_internal/batch/service.py b/src/clusterfuzz/_internal/batch/service.py
index be16cb6da3..7024b02bd1 100644
--- a/src/clusterfuzz/_internal/batch/service.py
+++ b/src/clusterfuzz/_internal/batch/service.py
@@ -21,6 +21,8 @@
 from typing import Dict
 from typing import List
 
+from google.cloud import batch_v1 as batch
+
 from clusterfuzz._internal.base import tasks
 from clusterfuzz._internal.base import utils
 from clusterfuzz._internal.base.tasks import task_utils
@@ -114,6 +116,17 @@ def _get_subconfig(batch_config, instance_spec):
   return all_subconfigs[weighted_subconfig.name]
 
 
+def _get_subconfig_for_region(batch_config, instance_spec, region):
+  """Returns the subconfig of |instance_spec| that is in |region|."""
+  all_subconfigs = batch_config.get('subconfigs', {})
+  instance_subconfigs = instance_spec['subconfigs']
+  for subconfig in instance_subconfigs:
+    full_subconfig = all_subconfigs[subconfig['name']]
+    if full_subconfig['region'] == region:
+      return full_subconfig
+  raise ValueError(f'No subconfig for region {region}')
+
+
 def _get_specs_from_config(batch_tasks) -> Dict:
   """Gets the configured specifications for a batch workload."""
   if not batch_tasks:
@@ -227,3 +240,22 @@ def create_uworker_main_batch_jobs(self, batch_tasks: List[BatchTask]):
       jobs.append(self._client.create_job(spec, input_urls_portion))
 
     return jobs
+
+  def create_congestion_job(self, job_type, gce_region=None):
+    """Creates a trivial congestion-probe job for |job_type|."""
+    batch_tasks = [BatchTask('fuzz', job_type, 'CONGESTION')]
+    specs = _get_specs_from_config(batch_tasks)
+    spec = specs[('fuzz', job_type)]
+    if gce_region:
+      batch_config = _get_batch_config()
+      config_map = _get_config_names(batch_tasks)
+      config_name, _, _ = config_map[('fuzz', job_type)]
+      instance_spec = batch_config.get('mapping').get(config_name)
+      subconfig = _get_subconfig_for_region(batch_config, instance_spec,
+                                            gce_region)
+      spec.gce_region = subconfig['region']
+      spec.network = subconfig['network']
+      spec.subnetwork = subconfig['subnetwork']
+
+    return self._client.create_job(
+        spec, ['CONGESTION'], commands=['echo', 'hello'])
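Note on the status check in check_congestion_jobs: any state proving the
scheduler picked the job up counts as completed. A condensed sketch of the
predicate (the enum values are real google.cloud.batch_v1 states; the helper
name is illustrative):

    from google.cloud import batch_v1 as batch

    # SUCCEEDED, RUNNING and FAILED all mean the job was scheduled;
    # QUEUED (or a job we cannot read) is treated as evidence of congestion.
    _SCHEDULED_STATES = (batch.JobStatus.State.SUCCEEDED,
                         batch.JobStatus.State.RUNNING,
                         batch.JobStatus.State.FAILED)

    def was_scheduled(job) -> bool:
      return job.status.state in _SCHEDULED_STATES

The probe itself overrides the container command with ['echo', 'hello'] so it
exits almost immediately; a probe's lifetime is dominated by scheduling
latency, which is exactly what is being measured.
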
diff --git a/src/clusterfuzz/_internal/cron/schedule_fuzz.py b/src/clusterfuzz/_internal/cron/schedule_fuzz.py
index 0a29697072..bd1d4231ef 100644
--- a/src/clusterfuzz/_internal/cron/schedule_fuzz.py
+++ b/src/clusterfuzz/_internal/cron/schedule_fuzz.py
@@ -14,6 +14,7 @@
 """Cron job to schedule fuzz tasks that run on batch."""
 
 import collections
+import datetime
 import multiprocessing
 import random
 import time
@@ -25,10 +26,10 @@
 
 from clusterfuzz._internal.base import tasks
 from clusterfuzz._internal.base import utils
-from clusterfuzz._internal.batch import gcp as batch
 from clusterfuzz._internal.config import local_config
 from clusterfuzz._internal.datastore import data_types
 from clusterfuzz._internal.datastore import ndb_utils
+from clusterfuzz._internal.google_cloud_utils import batch
 from clusterfuzz._internal.google_cloud_utils import credentials
 from clusterfuzz._internal.metrics import logs
@@ -295,7 +296,14 @@ def get_fuzz_tasks(self) -> List[tasks.Task]:
     return fuzz_tasks
 
 
-def get_fuzz_tasks(available_cpus: int) -> [tasks.Task]:
+def get_fuzz_tasks(project, regions) -> List[tasks.Task]:
+  """Returns fuzz tasks to schedule, based on the CPUs available in |regions|."""
+  available_cpus = get_available_cpus(project, regions)
+  logs.info(f'{available_cpus} available CPUs.')
+
+  if not available_cpus:
+    return []
+
   if utils.is_oss_fuzz():
     scheduler = OssfuzzFuzzTaskScheduler(available_cpus)
   else:
@@ -377,26 +385,93 @@ def respect_project_max_cpus(num_cpus):
   return min(max_cpus_per_schedule, num_cpus)
 
 
+def get_congested_regions() -> List[str]:
+  """Returns a list of congested regions. The strategy: every run of this
+  cron schedules a trivial congestion-probe job in each region. Assuming the
+  cron runs more than 3 times an hour, if the 3 most recent probes in a
+  region didn't complete within the last hour, they either failed (unlikely,
+  they are trivial) or never ran because of congestion.
+  """
+  one_hour_ago = datetime.datetime.utcnow() - datetime.timedelta(hours=1)
+  congestion_jobs = list(
+      data_types.CongestionJob.query(
+          data_types.CongestionJob.timestamp > one_hour_ago))
+
+  jobs_by_region = collections.defaultdict(list)
+  for job in congestion_jobs:
+    if job.region:
+      jobs_by_region[job.region].append(job)
+
+  congested_regions = []
+  for region, jobs in jobs_by_region.items():
+    # Sort by timestamp descending.
+    jobs.sort(key=lambda j: j.timestamp, reverse=True)
+    # Check the last 3 jobs.
+    recent_jobs = jobs[:3]
+    if len(recent_jobs) < 3:
+      continue
+
+    completed_count = batch.check_congestion_jobs(
+        [job.job_id for job in recent_jobs])
+    if completed_count < 3:
+      # TODO(metzman): Add some monitoring here.
+      logs.error(f'Congestion detected in {region}: {completed_count}/3 '
+                 'congestion jobs completed in the last hour.')
+      congested_regions.append(region)
+  return congested_regions
+
+
+def schedule_congestion_jobs(fuzz_tasks, all_regions):
+  """Schedules congestion jobs for all regions."""
+  # Run a hello world task that finishes very quickly. The job field is
+  # ignored, but we need one, so pick an arbitrary one.
+  clusterfuzz_job_type = None
+  if fuzz_tasks:
+    clusterfuzz_job_type = fuzz_tasks[0].job
+  else:
+    # If no tasks were scheduled, try to get a job type from the DB to run
+    # the congestion job.
+    job = data_types.Job.query().get()
+    if job:
+      clusterfuzz_job_type = job.name
+
+  if clusterfuzz_job_type:
+    for region in all_regions:
+      try:
+        batch_job_result = batch.create_congestion_job(
+            clusterfuzz_job_type, gce_region=region)
+        data_types.CongestionJob(
+            job_id=batch_job_result.name, region=region).put()
+      except Exception:
+        logs.error(f'Failed to create congestion job in {region}.')
+
+
 def schedule_fuzz_tasks() -> bool:
   """Schedules fuzz tasks."""
-  multiprocessing.set_start_method('spawn')
+  try:
+    multiprocessing.set_start_method('spawn')
+  except RuntimeError:
+    # Ignore if this was done previously.
+    pass
+
   batch_config = local_config.BatchConfig()
   project = batch_config.get('project')
-  regions = get_batch_regions(batch_config)
+  all_regions = get_batch_regions(batch_config)
+  congested_regions = get_congested_regions()
+  regions = [r for r in all_regions if r not in congested_regions]
+
   start = time.time()
-  available_cpus = get_available_cpus(project, regions)
-  logs.info(f'{available_cpus} available CPUs.')
-  if not available_cpus:
-    return False
+  fuzz_tasks = get_fuzz_tasks(project, regions)
 
-  fuzz_tasks = get_fuzz_tasks(available_cpus)
-  if not fuzz_tasks:
-    logs.error('No fuzz tasks found to schedule.')
-    return False
+  if fuzz_tasks:
+    logs.info(f'Adding {fuzz_tasks} to preprocess queue.')
+    tasks.bulk_add_tasks(fuzz_tasks, queue=tasks.PREPROCESS_QUEUE,
+                         eta_now=True)
+    logs.info(f'Scheduled {len(fuzz_tasks)} fuzz tasks.')
+  else:
+    logs.info('No fuzz tasks scheduled.')
 
-  logs.info(f'Adding {fuzz_tasks} to preprocess queue.')
-  tasks.bulk_add_tasks(fuzz_tasks, queue=tasks.PREPROCESS_QUEUE, eta_now=True)
-  logs.info(f'Scheduled {len(fuzz_tasks)} fuzz tasks.')
+  schedule_congestion_jobs(fuzz_tasks, all_regions)
 
   end = time.time()
   total = end - start
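One subtlety in the flow above: fuzz tasks are only scheduled to uncongested
regions, but congestion probes are still created in every region, so a
congested region keeps producing fresh signal and can recover. A simplified
sketch of the control flow (region names illustrative):

    all_regions = ['us-central1', 'us-west1']      # from the batch config
    congested = get_congested_regions()            # e.g. ['us-west1']
    regions = [r for r in all_regions if r not in congested]
    fuzz_tasks = get_fuzz_tasks(project, regions)  # healthy regions only
    schedule_congestion_jobs(fuzz_tasks, all_regions)  # probe ALL regions
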
diff --git a/src/clusterfuzz/_internal/datastore/data_types.py b/src/clusterfuzz/_internal/datastore/data_types.py
index 4e5af47dc2..441cd0e20e 100644
--- a/src/clusterfuzz/_internal/datastore/data_types.py
+++ b/src/clusterfuzz/_internal/datastore/data_types.py
@@ -1803,3 +1803,21 @@ class FuzzerTaskEvent(Model):
   def _pre_put_hook(self):
     self.ttl_expiry_timestamp = (
         datetime.datetime.now() + self.FUZZER_EVENT_TTL)
+
+
+class CongestionJob(Model):
+  """Congestion job. Used to measure congestion in batch."""
+  CONGESTION_JOB_TTL = datetime.timedelta(days=3)
+
+  # The job name (ID) in Batch.
+  job_id = ndb.StringProperty()
+  # The region the job is running in.
+  region = ndb.StringProperty()
+  # Time of creation.
+  timestamp = ndb.DateTimeProperty(auto_now_add=True)
+  # Expiration time for this entity.
+  ttl_expiry_timestamp = ndb.DateTimeProperty()
+
+  def _pre_put_hook(self):
+    self.ttl_expiry_timestamp = (
+        datetime.datetime.now() + self.CONGESTION_JOB_TTL)
diff --git a/src/clusterfuzz/_internal/google_cloud_utils/batch.py b/src/clusterfuzz/_internal/google_cloud_utils/batch.py
index e804044cef..85ed2ecf5c 100644
--- a/src/clusterfuzz/_internal/google_cloud_utils/batch.py
+++ b/src/clusterfuzz/_internal/google_cloud_utils/batch.py
@@ -29,3 +29,22 @@ def create_uworker_main_batch_jobs(batch_tasks: List[BatchTask]):
   """Creates batch jobs."""
   service = BatchService()
   return service.create_uworker_main_batch_jobs(batch_tasks)
+
+
+def create_congestion_job(job_type, gce_region=None):
+  """Creates a congestion job."""
+  service = BatchService()
+  return service.create_congestion_job(job_type, gce_region)
+
+
+def check_congestion_jobs(job_ids):
+  """Checks the status of the congestion jobs."""
+  from clusterfuzz._internal.batch import gcp
+  return gcp.check_congestion_jobs(job_ids)
+
+
+def count_queued_or_scheduled_tasks(project: str, region: str):
+  """Counts queued or scheduled tasks; delegates to gcp.py."""
+  # TODO(metzman): Move this into BatchService or GcpBatchClient.
+  from clusterfuzz._internal.batch import gcp
+  return gcp.count_queued_or_scheduled_tasks(project, region)
\ No newline at end of file
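The CongestionJob entity is append-only bookkeeping: the scheduler writes one
row per probe, and the TTL keeps the table small. A minimal usage sketch (the
job ID is illustrative; ttl_expiry_timestamp is stamped by _pre_put_hook on
put, so entities age out after CONGESTION_JOB_TTL):

    from clusterfuzz._internal.datastore import data_types

    # Written by schedule_congestion_jobs() right after creating the probe.
    data_types.CongestionJob(
        job_id='projects/proj/locations/us-central1/jobs/j-1234',
        region='us-central1').put()
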
diff --git a/src/clusterfuzz/_internal/tests/appengine/handlers/cron/schedule_fuzz_test.py b/src/clusterfuzz/_internal/tests/appengine/handlers/cron/schedule_fuzz_test.py
index 8b960d9977..ea2b095c47 100644
--- a/src/clusterfuzz/_internal/tests/appengine/handlers/cron/schedule_fuzz_test.py
+++ b/src/clusterfuzz/_internal/tests/appengine/handlers/cron/schedule_fuzz_test.py
@@ -286,3 +286,75 @@ def test_config_limit(self):
     }]
     self.assertEqual(
         schedule_fuzz.get_cpu_usage(self.creds, 'region', 'project'), (2, 0))
+
+
+@test_utils.with_cloud_emulators('datastore')
+class ScheduleFuzzTasksTest(unittest.TestCase):
+  """Tests for schedule_fuzz_tasks."""
+
+  def setUp(self):
+    test_helpers.patch(self, [
+        'clusterfuzz._internal.cron.schedule_fuzz.get_available_cpus',
+        'clusterfuzz._internal.cron.schedule_fuzz.get_fuzz_tasks',
+        'clusterfuzz._internal.cron.schedule_fuzz.get_batch_regions',
+        'clusterfuzz._internal.google_cloud_utils.batch.check_congestion_jobs',
+        'clusterfuzz._internal.google_cloud_utils.batch.create_congestion_job',
+        'clusterfuzz._internal.base.tasks.bulk_add_tasks',
+    ])
+    self.mock.get_batch_regions.return_value = ['us-central1']
+    mock_job = unittest.mock.Mock()
+    mock_job.name = 'congestion-job'
+    self.mock.create_congestion_job.return_value = mock_job
+
+  def test_is_congested_true(self):
+    """Tests that scheduling stops when congested."""
+    # Create 3 congestion jobs.
+    for i in range(3):
+      data_types.CongestionJob(job_id=f'job-{i}', region='us-central1').put()
+
+    # Mock check_congestion_jobs to return 0 completed.
+    self.mock.check_congestion_jobs.return_value = 0
+    # Mock get_fuzz_tasks to return an empty list (simulating no CPUs or
+    # other issues).
+    self.mock.get_fuzz_tasks.return_value = []
+
+    self.assertTrue(schedule_fuzz.schedule_fuzz_tasks())
+
+    self.mock.get_fuzz_tasks.assert_called()
+    # Verify it was called with an empty regions list. Args are
+    # (project, regions).
+    call_args = self.mock.get_fuzz_tasks.call_args[0]
+    self.assertEqual(call_args[1], [])
+
+  def test_is_congested_false(self):
+    """Tests that scheduling proceeds when not congested."""
+    # Create 3 congestion jobs.
+    for i in range(3):
+      data_types.CongestionJob(job_id=f'job-{i}', region='us-central1').put()
+
+    # Mock check_congestion_jobs to return 3 completed.
+    self.mock.check_congestion_jobs.return_value = 3
+    mock_task = unittest.mock.Mock()
+    mock_task.job = 'job1'
+    self.mock.get_fuzz_tasks.return_value = [mock_task]
+
+    self.assertTrue(schedule_fuzz.schedule_fuzz_tasks())
+    self.mock.get_fuzz_tasks.assert_called()
+
+  def test_no_congestion_job_if_no_tasks(self):
+    """Tests that no congestion job is scheduled when there are no fuzz
+    tasks and no Job entities exist."""
+    self.mock.get_fuzz_tasks.return_value = []
+
+    self.assertTrue(schedule_fuzz.schedule_fuzz_tasks())
+    self.mock.create_congestion_job.assert_not_called()
+
+  def test_congestion_job_scheduled(self):
+    """Tests that a congestion job is scheduled when fuzz tasks are."""
+    mock_task = unittest.mock.Mock()
+    mock_task.job = 'job1'
+    self.mock.get_fuzz_tasks.return_value = [mock_task]
+
+    self.assertTrue(schedule_fuzz.schedule_fuzz_tasks())
+    self.mock.create_congestion_job.assert_called_with(
+        'job1', gce_region='us-central1')
diff --git a/src/clusterfuzz/_internal/tests/core/batch/gcp_test.py b/src/clusterfuzz/_internal/tests/core/batch/gcp_test.py
new file mode 100644
index 0000000000..4479a248a9
--- /dev/null
+++ b/src/clusterfuzz/_internal/tests/core/batch/gcp_test.py
@@ -0,0 +1,71 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for the gcp module."""
+import unittest
+from unittest import mock
+
+from google.cloud import batch_v1 as batch
+
+from clusterfuzz._internal.batch import gcp
+from clusterfuzz._internal.tests.test_libs import helpers
+
+
+class GcpTest(unittest.TestCase):
+  """Tests for gcp module."""
+
+  def setUp(self):
+    helpers.patch(self, [
+        'clusterfuzz._internal.batch.gcp._batch_client',
+    ])
+    self.mock_batch_client_instance = mock.Mock()
+    self.mock._batch_client.return_value = self.mock_batch_client_instance
+
+  def test_check_congestion_jobs(self):
+    """Tests that check_congestion_jobs counts correctly."""
+    # Create mock jobs with different states.
+    job_succeeded = mock.Mock()
+    job_succeeded.status.state = batch.JobStatus.State.SUCCEEDED
+
+    job_running = mock.Mock()
+    job_running.status.state = batch.JobStatus.State.RUNNING
+
+    job_failed = mock.Mock()
+    job_failed.status.state = batch.JobStatus.State.FAILED
+
+    job_queued = mock.Mock()
+    job_queued.status.state = batch.JobStatus.State.QUEUED
+
+    # Mock get_job to return these based on job name.
+    def get_job_side_effect(name):
+      if name == 'job-succeeded':
+        return job_succeeded
+      if name == 'job-running':
+        return job_running
+      if name == 'job-failed':
+        return job_failed
+      if name == 'job-queued':
+        return job_queued
+      raise Exception('Job not found')
+
+    self.mock_batch_client_instance.get_job.side_effect = get_job_side_effect
+
+    # SUCCEEDED, RUNNING and FAILED are counted (3 total); QUEUED and the
+    # missing job are not.
+    job_ids = [
+        'job-succeeded', 'job-running', 'job-failed', 'job-queued',
+        'job-missing'
+    ]
+    count = gcp.check_congestion_jobs(job_ids)
+
+    self.assertEqual(count, 3)
diff --git a/src/clusterfuzz/_internal/tests/core/batch/service_test.py b/src/clusterfuzz/_internal/tests/core/batch/service_test.py
index 613a32603e..6d401f1f97 100644
--- a/src/clusterfuzz/_internal/tests/core/batch/service_test.py
+++ b/src/clusterfuzz/_internal/tests/core/batch/service_test.py
@@ -31,11 +31,13 @@
 UUIDS = [f'00000000-0000-0000-0000-{str(i).zfill(12)}' for i in range(100)]
 
 
-def _get_expected_task_spec(batch_workload_spec):
+def _get_expected_task_spec(batch_workload_spec, commands=None):
   """Gets the task spec based on the batch workload spec."""
   runnable = batch.Runnable()
   runnable.container = batch.Runnable.Container()
   runnable.container.image_uri = batch_workload_spec.docker_image
+  if commands:
+    runnable.container.commands = commands
   clusterfuzz_release = batch_workload_spec.clusterfuzz_release
   runnable.container.options = (
       '--memory-swappiness=40 --shm-size=1.9g --rm --net=host '
@@ -103,7 +105,8 @@ def _get_expected_allocation_policy(spec):
   return allocation_policy
 
 
-def _get_expected_create_request(job_name_uuid, spec, input_urls):
+def _get_expected_create_request(job_name_uuid, spec, input_urls,
+                                 commands=None):
   """Constructs and returns a `batch.CreateJobRequest` object.
 
   This function builds a complete `CreateJobRequest` for the GCP Batch service,
@@ -115,7 +118,7 @@
   project_id = spec.project
   parent = f'projects/{project_id}/locations/{spec.gce_region}'
 
-  task_spec = _get_expected_task_spec(spec)
+  task_spec = _get_expected_task_spec(spec, commands=commands)
 
   task_environments = [
       batch.Environment(variables={'UWORKER_INPUT_DOWNLOAD_URL': url})
@@ -257,6 +260,43 @@ def test_create_uworker_main_batch_job(self):
                                      expected_create_request)
     self.assertEqual(result, 'job')
 
+  def test_create_congestion_job(self):
+    """Tests that create_congestion_job works as expected."""
+    # Create mock data.
+    spec1 = service.BatchWorkloadSpec(
+        clusterfuzz_release='release1',
+        disk_size_gb=10,
+        disk_type='type1',
+        docker_image='image1',
+        user_data='user_data1',
+        service_account_email='email1',
+        subnetwork='subnetwork1',
+        preemptible=True,
+        project='project1',
+        machine_type='machine1',
+        network='network1',
+        gce_region='region1',
+        priority=0,
+        max_run_duration='1s',
+        retry=False)
+    with mock.patch(
+        'clusterfuzz._internal.batch.service._get_specs_from_config'
+    ) as mock_get_specs_from_config:
+      mock_get_specs_from_config.return_value = {
+          ('fuzz', 'job1'): spec1,
+      }
+      self.mock_batch_client_instance.create_job.return_value = 'job'
+
+      # Call the function.
+      result = self.batch_service.create_congestion_job('job1')
+
+      # Assert that create_job was called with the correct arguments.
+      expected_create_request = _get_expected_create_request(
+          UUIDS[0], spec1, ['CONGESTION'], commands=['echo', 'hello'])
+      self.mock_batch_client_instance.create_job.assert_called_with(
+          expected_create_request)
+      self.assertEqual(result, 'job')
+
 
 @test_utils.with_cloud_emulators('datastore')
 class IsRemoteTaskTest(unittest.TestCase):