SC2: Region access rule speedups (#5426)
This commit is contained in:
@@ -10,6 +10,10 @@ from BaseClasses import CollectionState
|
||||
if TYPE_CHECKING:
|
||||
from .nodes import SC2MOGenMission
|
||||
|
||||
def always_true(state: CollectionState) -> bool:
|
||||
"""Helper method to avoid creating trivial lambdas"""
|
||||
return True
|
||||
|
||||
|
||||
class EntryRule(ABC):
|
||||
buffer_fulfilled: bool
|
||||
@@ -60,6 +64,11 @@ class EntryRule(ABC):
|
||||
"""Used in the client to determine accessibility while playing and to populate tooltips."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def find_mandatory_mission(self) -> SC2MOGenMission | None:
|
||||
"""Should return any mission that is mandatory to fulfill the entry rule, or `None` if there is no such mission."""
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class RuleData(ABC):
|
||||
@@ -103,6 +112,11 @@ class BeatMissionsEntryRule(EntryRule):
|
||||
mission_ids,
|
||||
resolved_reqs
|
||||
)
|
||||
|
||||
def find_mandatory_mission(self) -> SC2MOGenMission | None:
|
||||
if len(self.missions_to_beat) > 0:
|
||||
return self.missions_to_beat[0]
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -140,7 +154,7 @@ class CountMissionsEntryRule(EntryRule):
|
||||
def __init__(self, missions_to_count: List[SC2MOGenMission], target_amount: int, visual_reqs: List[Union[str, SC2MOGenMission]]):
|
||||
super().__init__()
|
||||
self.missions_to_count = missions_to_count
|
||||
if target_amount == -1 or target_amount > len(missions_to_count):
|
||||
if target_amount <= -1 or target_amount > len(missions_to_count):
|
||||
self.target_amount = len(missions_to_count)
|
||||
else:
|
||||
self.target_amount = target_amount
|
||||
@@ -155,7 +169,20 @@ class CountMissionsEntryRule(EntryRule):
|
||||
return max(mission_depth, self.target_amount - 1) # -1 because depth is zero-based but amount is one-based
|
||||
|
||||
def to_lambda(self, player: int) -> Callable[[CollectionState], bool]:
|
||||
return lambda state: self.target_amount <= sum(state.has(mission.beat_item(), player) for mission in self.missions_to_count)
|
||||
if self.target_amount == 0:
|
||||
return always_true
|
||||
|
||||
beat_items = [mission.beat_item() for mission in self.missions_to_count]
|
||||
def count_missions(state: CollectionState) -> bool:
|
||||
count = 0
|
||||
for mission in range(len(self.missions_to_count)):
|
||||
if state.has(beat_items[mission], player):
|
||||
count += 1
|
||||
if count == self.target_amount:
|
||||
return True
|
||||
return False
|
||||
|
||||
return count_missions
|
||||
|
||||
def to_slot_data(self) -> RuleData:
|
||||
resolved_reqs: List[Union[str, int]] = [req if isinstance(req, str) else req.mission.id for req in self.visual_reqs]
|
||||
@@ -165,6 +192,11 @@ class CountMissionsEntryRule(EntryRule):
|
||||
self.target_amount,
|
||||
resolved_reqs
|
||||
)
|
||||
|
||||
def find_mandatory_mission(self) -> SC2MOGenMission | None:
|
||||
if self.target_amount > 0 and self.target_amount == len(self.missions_to_count):
|
||||
return self.missions_to_count[0]
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -216,13 +248,21 @@ class SubRuleEntryRule(EntryRule):
|
||||
self.rule_id = rule_id
|
||||
self.rules_to_check = rules_to_check
|
||||
self.min_depth = -1
|
||||
if target_amount == -1 or target_amount > len(rules_to_check):
|
||||
if target_amount <= -1 or target_amount > len(rules_to_check):
|
||||
self.target_amount = len(rules_to_check)
|
||||
else:
|
||||
self.target_amount = target_amount
|
||||
|
||||
def _is_fulfilled(self, beaten_missions: Set[SC2MOGenMission], in_region_check: bool) -> bool:
|
||||
return self.target_amount <= sum(rule.is_fulfilled(beaten_missions, in_region_check) for rule in self.rules_to_check)
|
||||
if len(self.rules_to_check) == 0:
|
||||
return True
|
||||
count = 0
|
||||
for rule in self.rules_to_check:
|
||||
if rule.is_fulfilled(beaten_missions, in_region_check):
|
||||
count += 1
|
||||
if count == self.target_amount:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _get_depth(self, beaten_missions: Set[SC2MOGenMission]) -> int:
|
||||
if len(self.rules_to_check) == 0:
|
||||
@@ -235,7 +275,21 @@ class SubRuleEntryRule(EntryRule):
|
||||
|
||||
def to_lambda(self, player: int) -> Callable[[CollectionState], bool]:
|
||||
sub_lambdas = [rule.to_lambda(player) for rule in self.rules_to_check]
|
||||
return lambda state, sub_lambdas=sub_lambdas: self.target_amount <= sum(sub_lambda(state) for sub_lambda in sub_lambdas)
|
||||
if self.target_amount == 0:
|
||||
return always_true
|
||||
if len(sub_lambdas) == 1:
|
||||
return sub_lambdas[0]
|
||||
|
||||
def count_rules(state: CollectionState) -> bool:
|
||||
count = 0
|
||||
for sub_lambda in sub_lambdas:
|
||||
if sub_lambda(state):
|
||||
count += 1
|
||||
if count == self.target_amount:
|
||||
return True
|
||||
return False
|
||||
|
||||
return count_rules
|
||||
|
||||
def to_slot_data(self) -> SubRuleRuleData:
|
||||
sub_rules = [rule.to_slot_data() for rule in self.rules_to_check]
|
||||
@@ -244,6 +298,14 @@ class SubRuleEntryRule(EntryRule):
|
||||
sub_rules,
|
||||
self.target_amount
|
||||
)
|
||||
|
||||
def find_mandatory_mission(self) -> SC2MOGenMission | None:
|
||||
if self.target_amount > 0 and self.target_amount == len(self.rules_to_check):
|
||||
for sub_rule in self.rules_to_check:
|
||||
mandatory_mission = sub_rule.find_mandatory_mission()
|
||||
if mandatory_mission is not None:
|
||||
return mandatory_mission
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -362,6 +424,9 @@ class ItemEntryRule(EntryRule):
|
||||
item_ids,
|
||||
visual_reqs
|
||||
)
|
||||
|
||||
def find_mandatory_mission(self) -> SC2MOGenMission | None:
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
Reference in New Issue
Block a user