Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions sc2/bot_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,12 +171,13 @@ def expansion_locations_dict(self) -> dict[Point2, Units]:
expansion_locations: dict[Point2, Units] = {pos: Units([], self) for pos in self._expansion_positions_list}
for resource in self.resources:
# It may be that some resources are not mapped to an expansion location
exp_position: Point2 | None = self._resource_location_to_expansion_position_dict.get(
exp_positions: set[Point2] | None = self._resource_location_to_expansion_position_dict.get(
resource.position, None
)
if exp_position:
assert exp_position in expansion_locations
expansion_locations[exp_position].append(resource)
if exp_positions:
for exp_position in exp_positions:
assert exp_position in expansion_locations
expansion_locations[exp_position].append(resource)
return expansion_locations

@property
Expand Down
211 changes: 180 additions & 31 deletions sc2/bot_ai_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def _initialize_variables(self) -> None:
self._all_units_previous_map: dict[int, Unit] = {}
self._previous_upgrades: set[UpgradeId] = set()
self._expansion_positions_list: list[Point2] = []
self._resource_location_to_expansion_position_dict: dict[Point2, Point2] = {}
self._resource_location_to_expansion_position_dict: dict[Point2, set[Point2]] = {}
self._time_before_step: float = 0
self._time_after_step: float = 0
self._min_step_time: float = math.inf
Expand Down Expand Up @@ -177,14 +177,154 @@ def expansion_locations(self) -> dict[Point2, Units]:
)
return self.expansion_locations_dict

def _cluster_center(self, group: list[Unit]) -> Point2:
"""
Calculates the geometric center (centroid) of a given group of units.

Parameters:
group: A list of Unit objects representing the group of units for
which the center is to be calculated.

Raises:
ValueError: If the provided group is empty.

Returns:
Point2: The calculated centroid of the group as a Point2 object.
"""
if not group:
raise ValueError("Cannot calculate center of empty group")

total_x = total_y = 0
for unit in group:
total_x += unit.position.x
total_y += unit.position.y

count = len(group)
return Point2((total_x / count, total_y / count))

def _find_expansion_location(
self, resources: Units | list[Unit], amount: int, offsets: list[tuple[float, float]]
) -> Point2:
"""
Finds the most suitable expansion location for resources.

Parameters:
resources: The list of resource entities or units near which the
expansion location needs to be found.
amount: The total number of resource entities or units to consider.
offsets (list[tuple[float, float]): A list of coordinate pairs denoting position
offsets to consider around the center of resources.

Returns:
The calculated optimal expansion Point2 if a suitable position is found;
otherwise, None.
"""
# Normal single expansion logic for regular bases
# Calculate center, round and add 0.5 because expansion location will have (x.5, y.5)
# coordinates because bases have size 5.
center_x = int(sum(resource.position.x for resource in resources) / amount) + 0.5
center_y = int(sum(resource.position.y for resource in resources) / amount) + 0.5
possible_points = (Point2((offset[0] + center_x, offset[1] + center_y)) for offset in offsets)
# Filter out points that are too near
possible_points = [
point
for point in possible_points
# Check if point can be built on
if self.game_info.placement_grid[point.rounded] == 1
# Check if all resources have enough space to point
and all(
point.distance_to(resource) >= (7 if resource._proto.unit_type in geyser_ids else 6)
for resource in resources
)
]
# Choose best fitting point
result: Point2 = min(
possible_points, key=lambda point: sum(point.distance_to(resource_) for resource_ in resources)
)
return result

def _has_opposite_side_geyser_layout(self, minerals: list[Unit], gas_geysers: list[Unit]) -> bool:
"""
Determines whether the gas geysers have an opposite-side mineral line layout.

The method evaluates if two gas geysers are located on opposite sides of a
mineral line.
If this returns True we consider this location has 2 valid expansion locations
either side of the mineral line.

Parameters:
minerals:
A list of mineral fields at this location.
gas_geysers : list[Unit]
A list of gas geysers at this location.

Returns:
bool
True if the geysers fulfill the opposite-side layout condition with
respect to the mineral line, otherwise False.
"""
# Need exactly 2 geysers and enough minerals for a line
if len(gas_geysers) != 2 or len(minerals) < 6:
return False

# Find the two minerals that are furthest apart
max_distance: float = 0.0
mineral_1: Unit = minerals[0]
mineral_2: Unit = minerals[1]

for i, m1 in enumerate(minerals):
for m2 in minerals[i + 1 :]:
distance = m1.distance_to(m2)
if distance > max_distance:
max_distance = distance
mineral_1 = m1
mineral_2 = m2

# ensure line is long enough
if max_distance < 4:
return False

# Create line from the two furthest minerals
x1, y1 = mineral_1.position.x, mineral_1.position.y
x2, y2 = mineral_2.position.x, mineral_2.position.y

geyser_1, geyser_2 = gas_geysers

# Check if the mineral line is more vertical than horizontal
if abs(x2 - x1) < 0.1:
# Vertical line: use x-coordinate to determine sides
line_x = (x1 + x2) / 2

side_1 = geyser_1.position.x - line_x
side_2 = geyser_2.position.x - line_x

# Must be on opposite sides and far enough from the line
return side_1 * side_2 < 0 and abs(side_1) > 3 and abs(side_2) > 3

# Calculate line equation: y = mx + b
slope = (y2 - y1) / (x2 - x1)
intercept = y1 - slope * x1

# Function to determine which side of the line a point is on
def side_of_line(point: Point2) -> float:
return point.y - slope * point.x - intercept

side_1 = side_of_line(geyser_1.position)
side_2 = side_of_line(geyser_2.position)

# Check if geysers are on opposite sides
opposite_sides = side_1 * side_2 < 0

return opposite_sides

@final
def _find_expansion_locations(self) -> None:
"""Ran once at the start of the game to calculate expansion locations."""
# Idea: create a group for every resource, then merge these groups if
# any resource in a group is closer than a threshold to any resource of another group

# Distance we group resources by
resource_spread_threshold: float = 8.5
resource_spread_threshold: float = 10.5
# Create a group for every resource
resource_groups: list[list[Unit]] = [
[resource]
Expand All @@ -200,22 +340,23 @@ def _find_expansion_locations(self) -> None:
for group_a, group_b in itertools.combinations(resource_groups, 2):
# Check if any pair of resource of these groups is closer than threshold together
# And that they are on the same terrain level
if any(
resource_a.distance_to(resource_b) <= resource_spread_threshold
# check if terrain height measurement at resources is within 10 units
# this is since some older maps have inconsistent terrain height
# tiles at certain expansion locations
and abs(height_grid[resource_a.position.rounded] - height_grid[resource_b.position.rounded]) <= 10
for resource_a, resource_b in itertools.product(group_a, group_b)
center_a = self._cluster_center(group_a)
center_b = self._cluster_center(group_b)

if center_a.distance_to(center_b) <= resource_spread_threshold and all(
abs(height_grid[res_a.position.rounded] - height_grid[res_b.position.rounded]) <= 10
for res_a in group_a
for res_b in group_b
):
# Remove the single groups and add the merged group
resource_groups.remove(group_a)
resource_groups.remove(group_b)
resource_groups.append(group_a + group_b)
merged_group = True
break

# Distance offsets we apply to center of each resource group to find expansion position
offset_range = 7
offset_range: int = 7
offsets = [
(x, y)
for x, y in itertools.product(range(-offset_range, offset_range + 1), repeat=2)
Expand All @@ -227,33 +368,41 @@ def _find_expansion_locations(self) -> None:
for resources in resource_groups:
# Possible expansion points
amount = len(resources)
# Calculate center, round and add 0.5 because expansion location will have (x.5, y.5)
# coordinates because bases have size 5.
center_x = int(sum(resource.position.x for resource in resources) / amount) + 0.5
center_y = int(sum(resource.position.y for resource in resources) / amount) + 0.5
possible_points = (Point2((offset[0] + center_x, offset[1] + center_y)) for offset in offsets)
# Filter out points that are too near
possible_points = (
point
for point in possible_points
# Check if point can be built on
if self.game_info.placement_grid[point.rounded] == 1
# Check if all resources have enough space to point
and all(
point.distance_to(resource) >= (7 if resource._proto.unit_type in geyser_ids else 6)
for resource in resources
)
)
# this check is needed for TorchesAIE where the gold mineral wall has a
# unit type of `RichMineralField` so we can only filter out by amount of resources
if amount > 12:
continue

minerals = [r for r in resources if r._proto.unit_type not in geyser_ids]
gas_geysers = [r for r in resources if r._proto.unit_type in geyser_ids]

# Check if we have exactly 2 gas geysers positioned above/below the mineral line
# Needed for TorchesAIE where one gold base has 2 expansion locations
if self._has_opposite_side_geyser_layout(minerals, gas_geysers):
# Create expansion locations for each geyser + minerals
for geyser in gas_geysers:
local_resources = minerals + [geyser]
result: Point2 = self._find_expansion_location(local_resources, len(local_resources), offsets)
centers[result] = local_resources
# Put all expansion locations in a list
self._expansion_positions_list.append(result)
# Maps all resource positions to the expansion position
for resource in local_resources:
if resource.position in self._resource_location_to_expansion_position_dict:
self._resource_location_to_expansion_position_dict[resource.position].add(result)
else:
self._resource_location_to_expansion_position_dict[resource.position] = {result}

continue

# Choose best fitting point
result: Point2 = min(
possible_points, key=lambda point: sum(point.distance_to(resource_) for resource_ in resources)
)
result: Point2 = self._find_expansion_location(resources, amount, offsets)
centers[result] = resources
# Put all expansion locations in a list
self._expansion_positions_list.append(result)
# Maps all resource positions to the expansion position
for resource in resources:
self._resource_location_to_expansion_position_dict[resource.position] = result
self._resource_location_to_expansion_position_dict[resource.position] = {result}

@final
def _correct_zerg_supply(self) -> None:
Expand Down
3 changes: 3 additions & 0 deletions test/generate_pickle_files_bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,11 +175,13 @@ def main():
"OxideAIE",
"PaladinoTerminalLE",
"ParaSiteLE",
"PersephoneAIE",
"PillarsofGold506",
"PillarsofGoldLE",
"PortAleksanderLE",
"PrimusQ9",
"ProximaStationLE",
"PylonAIE",
"RedshiftLE",
"Reminiscence",
"RomanticideAIE",
Expand All @@ -193,6 +195,7 @@ def main():
"StasisLE",
"TheTimelessVoid",
"ThunderbirdLE",
"TorchesAIE",
"Treachery",
"Triton",
"Urzagol",
Expand Down
Binary file added test/pickle_data/PersephoneAIE.xz
Binary file not shown.
Binary file added test/pickle_data/PylonAIE.xz
Binary file not shown.
Binary file added test/pickle_data/TorchesAIE.xz
Binary file not shown.
5 changes: 4 additions & 1 deletion test/test_pickled_ramp.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from loguru import logger

from sc2.game_info import Ramp
from sc2.ids.unit_typeid import UnitTypeId
from sc2.position import Point2
from sc2.unit import Unit
from sc2.units import Units
Expand All @@ -36,6 +37,8 @@ class TestClass:
# Load all pickle files and convert them into bot objects from raw data (game_data, game_info, game_state)
scenarios = [(map_path.name, {"map_path": map_path}) for map_path in MAPS]

MAPS_WITH_ODD_EXPANSION_COUNT: set[UnitTypeId] = {"Persephone AIE", "StargazersAIE", "Stasis LE"}

def test_main_base_ramp(self, map_path: Path):
bot = get_map_specific_bot(map_path)
# pyre-ignore[16]
Expand Down Expand Up @@ -105,7 +108,7 @@ def test_bot_ai(self, map_path: Path):
# On N player maps, it is expected that there are N*X bases because of symmetry, at least for maps designed for 1vs1
# Those maps in the list have an un-even expansion count
# pyre-ignore[16]
expect_even_expansion_count = 1 if bot.game_info.map_name in ["StargazersAIE", "Stasis LE"] else 0
expect_even_expansion_count = 1 if bot.game_info.map_name in self.MAPS_WITH_ODD_EXPANSION_COUNT else 0
assert (
len(bot.expansion_locations_list) % (len(bot.enemy_start_locations) + 1) == expect_even_expansion_count
), f"{bot.expansion_locations_list}"
Expand Down
Loading