SC2: Region access rule speedups (#5426)

This commit is contained in:
Salzkorn
2025-09-12 23:48:29 +02:00
committed by GitHub
parent 76a8b0d582
commit 4e085894d2
2 changed files with 120 additions and 18 deletions

View File

@@ -10,6 +10,10 @@ from BaseClasses import CollectionState
if TYPE_CHECKING:
from .nodes import SC2MOGenMission
def always_true(state: CollectionState) -> bool:
"""Helper method to avoid creating trivial lambdas"""
return True
class EntryRule(ABC):
buffer_fulfilled: bool
@@ -60,6 +64,11 @@ class EntryRule(ABC):
"""Used in the client to determine accessibility while playing and to populate tooltips."""
pass
@abstractmethod
def find_mandatory_mission(self) -> SC2MOGenMission | None:
"""Should return any mission that is mandatory to fulfill the entry rule, or `None` if there is no such mission."""
return None
@dataclass
class RuleData(ABC):
@@ -103,6 +112,11 @@ class BeatMissionsEntryRule(EntryRule):
mission_ids,
resolved_reqs
)
def find_mandatory_mission(self) -> SC2MOGenMission | None:
if len(self.missions_to_beat) > 0:
return self.missions_to_beat[0]
return None
@dataclass
@@ -140,7 +154,7 @@ class CountMissionsEntryRule(EntryRule):
def __init__(self, missions_to_count: List[SC2MOGenMission], target_amount: int, visual_reqs: List[Union[str, SC2MOGenMission]]):
super().__init__()
self.missions_to_count = missions_to_count
if target_amount == -1 or target_amount > len(missions_to_count):
if target_amount <= -1 or target_amount > len(missions_to_count):
self.target_amount = len(missions_to_count)
else:
self.target_amount = target_amount
@@ -155,7 +169,20 @@ class CountMissionsEntryRule(EntryRule):
return max(mission_depth, self.target_amount - 1) # -1 because depth is zero-based but amount is one-based
def to_lambda(self, player: int) -> Callable[[CollectionState], bool]:
return lambda state: self.target_amount <= sum(state.has(mission.beat_item(), player) for mission in self.missions_to_count)
if self.target_amount == 0:
return always_true
beat_items = [mission.beat_item() for mission in self.missions_to_count]
def count_missions(state: CollectionState) -> bool:
count = 0
for mission in range(len(self.missions_to_count)):
if state.has(beat_items[mission], player):
count += 1
if count == self.target_amount:
return True
return False
return count_missions
def to_slot_data(self) -> RuleData:
resolved_reqs: List[Union[str, int]] = [req if isinstance(req, str) else req.mission.id for req in self.visual_reqs]
@@ -165,6 +192,11 @@ class CountMissionsEntryRule(EntryRule):
self.target_amount,
resolved_reqs
)
def find_mandatory_mission(self) -> SC2MOGenMission | None:
if self.target_amount > 0 and self.target_amount == len(self.missions_to_count):
return self.missions_to_count[0]
return None
@dataclass
@@ -216,13 +248,21 @@ class SubRuleEntryRule(EntryRule):
self.rule_id = rule_id
self.rules_to_check = rules_to_check
self.min_depth = -1
if target_amount == -1 or target_amount > len(rules_to_check):
if target_amount <= -1 or target_amount > len(rules_to_check):
self.target_amount = len(rules_to_check)
else:
self.target_amount = target_amount
def _is_fulfilled(self, beaten_missions: Set[SC2MOGenMission], in_region_check: bool) -> bool:
return self.target_amount <= sum(rule.is_fulfilled(beaten_missions, in_region_check) for rule in self.rules_to_check)
if len(self.rules_to_check) == 0:
return True
count = 0
for rule in self.rules_to_check:
if rule.is_fulfilled(beaten_missions, in_region_check):
count += 1
if count == self.target_amount:
return True
return False
def _get_depth(self, beaten_missions: Set[SC2MOGenMission]) -> int:
if len(self.rules_to_check) == 0:
@@ -235,7 +275,21 @@ class SubRuleEntryRule(EntryRule):
def to_lambda(self, player: int) -> Callable[[CollectionState], bool]:
sub_lambdas = [rule.to_lambda(player) for rule in self.rules_to_check]
return lambda state, sub_lambdas=sub_lambdas: self.target_amount <= sum(sub_lambda(state) for sub_lambda in sub_lambdas)
if self.target_amount == 0:
return always_true
if len(sub_lambdas) == 1:
return sub_lambdas[0]
def count_rules(state: CollectionState) -> bool:
count = 0
for sub_lambda in sub_lambdas:
if sub_lambda(state):
count += 1
if count == self.target_amount:
return True
return False
return count_rules
def to_slot_data(self) -> SubRuleRuleData:
sub_rules = [rule.to_slot_data() for rule in self.rules_to_check]
@@ -244,6 +298,14 @@ class SubRuleEntryRule(EntryRule):
sub_rules,
self.target_amount
)
def find_mandatory_mission(self) -> SC2MOGenMission | None:
if self.target_amount > 0 and self.target_amount == len(self.rules_to_check):
for sub_rule in self.rules_to_check:
mandatory_mission = sub_rule.find_mandatory_mission()
if mandatory_mission is not None:
return mandatory_mission
return None
@dataclass
@@ -362,6 +424,9 @@ class ItemEntryRule(EntryRule):
item_ids,
visual_reqs
)
def find_mandatory_mission(self) -> SC2MOGenMission | None:
return None
@dataclass