SC2: Region access rule speedups (#5426)

This commit is contained in:
Salzkorn
2025-09-12 23:48:29 +02:00
committed by GitHub
parent 76a8b0d582
commit 4e085894d2
2 changed files with 120 additions and 18 deletions

View File

@@ -10,6 +10,10 @@ from BaseClasses import CollectionState
if TYPE_CHECKING:
from .nodes import SC2MOGenMission
def always_true(state: CollectionState) -> bool:
"""Helper method to avoid creating trivial lambdas"""
return True
class EntryRule(ABC):
buffer_fulfilled: bool
@@ -60,6 +64,11 @@ class EntryRule(ABC):
"""Used in the client to determine accessibility while playing and to populate tooltips."""
pass
@abstractmethod
def find_mandatory_mission(self) -> SC2MOGenMission | None:
"""Should return any mission that is mandatory to fulfill the entry rule, or `None` if there is no such mission."""
return None
@dataclass
class RuleData(ABC):
@@ -103,6 +112,11 @@ class BeatMissionsEntryRule(EntryRule):
mission_ids,
resolved_reqs
)
def find_mandatory_mission(self) -> SC2MOGenMission | None:
if len(self.missions_to_beat) > 0:
return self.missions_to_beat[0]
return None
@dataclass
@@ -140,7 +154,7 @@ class CountMissionsEntryRule(EntryRule):
def __init__(self, missions_to_count: List[SC2MOGenMission], target_amount: int, visual_reqs: List[Union[str, SC2MOGenMission]]):
super().__init__()
self.missions_to_count = missions_to_count
if target_amount == -1 or target_amount > len(missions_to_count):
if target_amount <= -1 or target_amount > len(missions_to_count):
self.target_amount = len(missions_to_count)
else:
self.target_amount = target_amount
@@ -155,7 +169,20 @@ class CountMissionsEntryRule(EntryRule):
return max(mission_depth, self.target_amount - 1) # -1 because depth is zero-based but amount is one-based
def to_lambda(self, player: int) -> Callable[[CollectionState], bool]:
return lambda state: self.target_amount <= sum(state.has(mission.beat_item(), player) for mission in self.missions_to_count)
if self.target_amount == 0:
return always_true
beat_items = [mission.beat_item() for mission in self.missions_to_count]
def count_missions(state: CollectionState) -> bool:
count = 0
for mission in range(len(self.missions_to_count)):
if state.has(beat_items[mission], player):
count += 1
if count == self.target_amount:
return True
return False
return count_missions
def to_slot_data(self) -> RuleData:
resolved_reqs: List[Union[str, int]] = [req if isinstance(req, str) else req.mission.id for req in self.visual_reqs]
@@ -165,6 +192,11 @@ class CountMissionsEntryRule(EntryRule):
self.target_amount,
resolved_reqs
)
def find_mandatory_mission(self) -> SC2MOGenMission | None:
if self.target_amount > 0 and self.target_amount == len(self.missions_to_count):
return self.missions_to_count[0]
return None
@dataclass
@@ -216,13 +248,21 @@ class SubRuleEntryRule(EntryRule):
self.rule_id = rule_id
self.rules_to_check = rules_to_check
self.min_depth = -1
if target_amount == -1 or target_amount > len(rules_to_check):
if target_amount <= -1 or target_amount > len(rules_to_check):
self.target_amount = len(rules_to_check)
else:
self.target_amount = target_amount
def _is_fulfilled(self, beaten_missions: Set[SC2MOGenMission], in_region_check: bool) -> bool:
return self.target_amount <= sum(rule.is_fulfilled(beaten_missions, in_region_check) for rule in self.rules_to_check)
if len(self.rules_to_check) == 0:
return True
count = 0
for rule in self.rules_to_check:
if rule.is_fulfilled(beaten_missions, in_region_check):
count += 1
if count == self.target_amount:
return True
return False
def _get_depth(self, beaten_missions: Set[SC2MOGenMission]) -> int:
if len(self.rules_to_check) == 0:
@@ -235,7 +275,21 @@ class SubRuleEntryRule(EntryRule):
def to_lambda(self, player: int) -> Callable[[CollectionState], bool]:
sub_lambdas = [rule.to_lambda(player) for rule in self.rules_to_check]
return lambda state, sub_lambdas=sub_lambdas: self.target_amount <= sum(sub_lambda(state) for sub_lambda in sub_lambdas)
if self.target_amount == 0:
return always_true
if len(sub_lambdas) == 1:
return sub_lambdas[0]
def count_rules(state: CollectionState) -> bool:
count = 0
for sub_lambda in sub_lambdas:
if sub_lambda(state):
count += 1
if count == self.target_amount:
return True
return False
return count_rules
def to_slot_data(self) -> SubRuleRuleData:
sub_rules = [rule.to_slot_data() for rule in self.rules_to_check]
@@ -244,6 +298,14 @@ class SubRuleEntryRule(EntryRule):
sub_rules,
self.target_amount
)
def find_mandatory_mission(self) -> SC2MOGenMission | None:
if self.target_amount > 0 and self.target_amount == len(self.rules_to_check):
for sub_rule in self.rules_to_check:
mandatory_mission = sub_rule.find_mandatory_mission()
if mandatory_mission is not None:
return mandatory_mission
return None
@dataclass
@@ -362,6 +424,9 @@ class ItemEntryRule(EntryRule):
item_ids,
visual_reqs
)
def find_mandatory_mission(self) -> SC2MOGenMission | None:
return None
@dataclass

View File

@@ -491,23 +491,60 @@ def make_connections(mission_order: SC2MOGenMissionOrder, world: 'SC2World'):
for layout in campaign.layouts:
for mission in layout.missions:
if not mission.option_empty:
mission_uses_rule = mission.entry_rule.target_amount > 0
mission_rule = mission.entry_rule.to_lambda(player)
mandatory_prereq = mission.entry_rule.find_mandatory_mission()
# Only layout entrances need to consider campaign & layout prerequisites
if mission.option_entrance:
campaign_rule = mission.parent().parent().entry_rule.to_lambda(player)
layout_rule = mission.parent().entry_rule.to_lambda(player)
unlock_rule = lambda state, campaign_rule=campaign_rule, layout_rule=layout_rule, mission_rule=mission_rule: \
campaign_rule(state) and layout_rule(state) and mission_rule(state)
else:
campaign_uses_rule = campaign.entry_rule.target_amount > 0
campaign_rule = campaign.entry_rule.to_lambda(player)
layout_uses_rule = layout.entry_rule.target_amount > 0
layout_rule = layout.entry_rule.to_lambda(player)
# Any mandatory prerequisite mission is good enough
mandatory_prereq = campaign.entry_rule.find_mandatory_mission() if mandatory_prereq is None else mandatory_prereq
mandatory_prereq = layout.entry_rule.find_mandatory_mission() if mandatory_prereq is None else mandatory_prereq
# Avoid calling obviously unused lambdas
if campaign_uses_rule:
if layout_uses_rule:
if mission_uses_rule:
unlock_rule = lambda state, campaign_rule=campaign_rule, layout_rule=layout_rule, mission_rule=mission_rule: \
campaign_rule(state) and layout_rule(state) and mission_rule(state)
else:
unlock_rule = lambda state, campaign_rule=campaign_rule, layout_rule=layout_rule: \
campaign_rule(state) and layout_rule(state)
else:
if mission_uses_rule:
unlock_rule = lambda state, campaign_rule=campaign_rule, mission_rule=mission_rule: \
campaign_rule(state) and mission_rule(state)
else:
unlock_rule = campaign_rule
elif layout_uses_rule:
if mission_uses_rule:
unlock_rule = lambda state, layout_rule=layout_rule, mission_rule=mission_rule: \
layout_rule(state) and mission_rule(state)
else:
unlock_rule = layout_rule
elif mission_uses_rule:
unlock_rule = mission_rule
else:
unlock_rule = None
elif mission_uses_rule:
unlock_rule = mission_rule
# Individually connect to previous missions
for prev_mission in mission.prev:
connect(world, names, prev_mission.mission.mission_name, mission.mission.mission_name,
lambda state, unlock_rule=unlock_rule: unlock_rule(state))
# If there are no previous missions, connect to Menu instead
if len(mission.prev) == 0:
connect(world, names, "Menu", mission.mission.mission_name,
lambda state, unlock_rule=unlock_rule: unlock_rule(state))
else:
unlock_rule = None
# Connect to a discovered mandatory mission if possible
if mandatory_prereq is not None:
connect(world, names, mandatory_prereq.mission.mission_name, mission.mission.mission_name, unlock_rule)
else:
# If no mission is known to be mandatory, connect to all previous missions instead
for prev_mission in mission.prev:
connect(world, names, prev_mission.mission.mission_name, mission.mission.mission_name, unlock_rule)
# As a last resort connect to Menu
if len(mission.prev) == 0:
connect(world, names, "Menu", mission.mission.mission_name, unlock_rule)
def connect(world: 'SC2World', used_names: Dict[str, int], source: str, target: str,