SC2: Region access rule speedups (#5426)
This commit is contained in:
@@ -10,6 +10,10 @@ from BaseClasses import CollectionState
|
||||
if TYPE_CHECKING:
|
||||
from .nodes import SC2MOGenMission
|
||||
|
||||
def always_true(state: CollectionState) -> bool:
|
||||
"""Helper method to avoid creating trivial lambdas"""
|
||||
return True
|
||||
|
||||
|
||||
class EntryRule(ABC):
|
||||
buffer_fulfilled: bool
|
||||
@@ -60,6 +64,11 @@ class EntryRule(ABC):
|
||||
"""Used in the client to determine accessibility while playing and to populate tooltips."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def find_mandatory_mission(self) -> SC2MOGenMission | None:
|
||||
"""Should return any mission that is mandatory to fulfill the entry rule, or `None` if there is no such mission."""
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class RuleData(ABC):
|
||||
@@ -104,6 +113,11 @@ class BeatMissionsEntryRule(EntryRule):
|
||||
resolved_reqs
|
||||
)
|
||||
|
||||
def find_mandatory_mission(self) -> SC2MOGenMission | None:
|
||||
if len(self.missions_to_beat) > 0:
|
||||
return self.missions_to_beat[0]
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class BeatMissionsRuleData(RuleData):
|
||||
@@ -140,7 +154,7 @@ class CountMissionsEntryRule(EntryRule):
|
||||
def __init__(self, missions_to_count: List[SC2MOGenMission], target_amount: int, visual_reqs: List[Union[str, SC2MOGenMission]]):
|
||||
super().__init__()
|
||||
self.missions_to_count = missions_to_count
|
||||
if target_amount == -1 or target_amount > len(missions_to_count):
|
||||
if target_amount <= -1 or target_amount > len(missions_to_count):
|
||||
self.target_amount = len(missions_to_count)
|
||||
else:
|
||||
self.target_amount = target_amount
|
||||
@@ -155,7 +169,20 @@ class CountMissionsEntryRule(EntryRule):
|
||||
return max(mission_depth, self.target_amount - 1) # -1 because depth is zero-based but amount is one-based
|
||||
|
||||
def to_lambda(self, player: int) -> Callable[[CollectionState], bool]:
|
||||
return lambda state: self.target_amount <= sum(state.has(mission.beat_item(), player) for mission in self.missions_to_count)
|
||||
if self.target_amount == 0:
|
||||
return always_true
|
||||
|
||||
beat_items = [mission.beat_item() for mission in self.missions_to_count]
|
||||
def count_missions(state: CollectionState) -> bool:
|
||||
count = 0
|
||||
for mission in range(len(self.missions_to_count)):
|
||||
if state.has(beat_items[mission], player):
|
||||
count += 1
|
||||
if count == self.target_amount:
|
||||
return True
|
||||
return False
|
||||
|
||||
return count_missions
|
||||
|
||||
def to_slot_data(self) -> RuleData:
|
||||
resolved_reqs: List[Union[str, int]] = [req if isinstance(req, str) else req.mission.id for req in self.visual_reqs]
|
||||
@@ -166,6 +193,11 @@ class CountMissionsEntryRule(EntryRule):
|
||||
resolved_reqs
|
||||
)
|
||||
|
||||
def find_mandatory_mission(self) -> SC2MOGenMission | None:
|
||||
if self.target_amount > 0 and self.target_amount == len(self.missions_to_count):
|
||||
return self.missions_to_count[0]
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class CountMissionsRuleData(RuleData):
|
||||
@@ -216,13 +248,21 @@ class SubRuleEntryRule(EntryRule):
|
||||
self.rule_id = rule_id
|
||||
self.rules_to_check = rules_to_check
|
||||
self.min_depth = -1
|
||||
if target_amount == -1 or target_amount > len(rules_to_check):
|
||||
if target_amount <= -1 or target_amount > len(rules_to_check):
|
||||
self.target_amount = len(rules_to_check)
|
||||
else:
|
||||
self.target_amount = target_amount
|
||||
|
||||
def _is_fulfilled(self, beaten_missions: Set[SC2MOGenMission], in_region_check: bool) -> bool:
|
||||
return self.target_amount <= sum(rule.is_fulfilled(beaten_missions, in_region_check) for rule in self.rules_to_check)
|
||||
if len(self.rules_to_check) == 0:
|
||||
return True
|
||||
count = 0
|
||||
for rule in self.rules_to_check:
|
||||
if rule.is_fulfilled(beaten_missions, in_region_check):
|
||||
count += 1
|
||||
if count == self.target_amount:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _get_depth(self, beaten_missions: Set[SC2MOGenMission]) -> int:
|
||||
if len(self.rules_to_check) == 0:
|
||||
@@ -235,7 +275,21 @@ class SubRuleEntryRule(EntryRule):
|
||||
|
||||
def to_lambda(self, player: int) -> Callable[[CollectionState], bool]:
|
||||
sub_lambdas = [rule.to_lambda(player) for rule in self.rules_to_check]
|
||||
return lambda state, sub_lambdas=sub_lambdas: self.target_amount <= sum(sub_lambda(state) for sub_lambda in sub_lambdas)
|
||||
if self.target_amount == 0:
|
||||
return always_true
|
||||
if len(sub_lambdas) == 1:
|
||||
return sub_lambdas[0]
|
||||
|
||||
def count_rules(state: CollectionState) -> bool:
|
||||
count = 0
|
||||
for sub_lambda in sub_lambdas:
|
||||
if sub_lambda(state):
|
||||
count += 1
|
||||
if count == self.target_amount:
|
||||
return True
|
||||
return False
|
||||
|
||||
return count_rules
|
||||
|
||||
def to_slot_data(self) -> SubRuleRuleData:
|
||||
sub_rules = [rule.to_slot_data() for rule in self.rules_to_check]
|
||||
@@ -245,6 +299,14 @@ class SubRuleEntryRule(EntryRule):
|
||||
self.target_amount
|
||||
)
|
||||
|
||||
def find_mandatory_mission(self) -> SC2MOGenMission | None:
|
||||
if self.target_amount > 0 and self.target_amount == len(self.rules_to_check):
|
||||
for sub_rule in self.rules_to_check:
|
||||
mandatory_mission = sub_rule.find_mandatory_mission()
|
||||
if mandatory_mission is not None:
|
||||
return mandatory_mission
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class SubRuleRuleData(RuleData):
|
||||
@@ -363,6 +425,9 @@ class ItemEntryRule(EntryRule):
|
||||
visual_reqs
|
||||
)
|
||||
|
||||
def find_mandatory_mission(self) -> SC2MOGenMission | None:
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ItemRuleData(RuleData):
|
||||
|
||||
@@ -491,23 +491,60 @@ def make_connections(mission_order: SC2MOGenMissionOrder, world: 'SC2World'):
|
||||
for layout in campaign.layouts:
|
||||
for mission in layout.missions:
|
||||
if not mission.option_empty:
|
||||
mission_uses_rule = mission.entry_rule.target_amount > 0
|
||||
mission_rule = mission.entry_rule.to_lambda(player)
|
||||
mandatory_prereq = mission.entry_rule.find_mandatory_mission()
|
||||
# Only layout entrances need to consider campaign & layout prerequisites
|
||||
if mission.option_entrance:
|
||||
campaign_rule = mission.parent().parent().entry_rule.to_lambda(player)
|
||||
layout_rule = mission.parent().entry_rule.to_lambda(player)
|
||||
campaign_uses_rule = campaign.entry_rule.target_amount > 0
|
||||
campaign_rule = campaign.entry_rule.to_lambda(player)
|
||||
layout_uses_rule = layout.entry_rule.target_amount > 0
|
||||
layout_rule = layout.entry_rule.to_lambda(player)
|
||||
|
||||
# Any mandatory prerequisite mission is good enough
|
||||
mandatory_prereq = campaign.entry_rule.find_mandatory_mission() if mandatory_prereq is None else mandatory_prereq
|
||||
mandatory_prereq = layout.entry_rule.find_mandatory_mission() if mandatory_prereq is None else mandatory_prereq
|
||||
|
||||
# Avoid calling obviously unused lambdas
|
||||
if campaign_uses_rule:
|
||||
if layout_uses_rule:
|
||||
if mission_uses_rule:
|
||||
unlock_rule = lambda state, campaign_rule=campaign_rule, layout_rule=layout_rule, mission_rule=mission_rule: \
|
||||
campaign_rule(state) and layout_rule(state) and mission_rule(state)
|
||||
else:
|
||||
unlock_rule = lambda state, campaign_rule=campaign_rule, layout_rule=layout_rule: \
|
||||
campaign_rule(state) and layout_rule(state)
|
||||
else:
|
||||
if mission_uses_rule:
|
||||
unlock_rule = lambda state, campaign_rule=campaign_rule, mission_rule=mission_rule: \
|
||||
campaign_rule(state) and mission_rule(state)
|
||||
else:
|
||||
unlock_rule = campaign_rule
|
||||
elif layout_uses_rule:
|
||||
if mission_uses_rule:
|
||||
unlock_rule = lambda state, layout_rule=layout_rule, mission_rule=mission_rule: \
|
||||
layout_rule(state) and mission_rule(state)
|
||||
else:
|
||||
unlock_rule = layout_rule
|
||||
elif mission_uses_rule:
|
||||
unlock_rule = mission_rule
|
||||
# Individually connect to previous missions
|
||||
else:
|
||||
unlock_rule = None
|
||||
elif mission_uses_rule:
|
||||
unlock_rule = mission_rule
|
||||
else:
|
||||
unlock_rule = None
|
||||
|
||||
# Connect to a discovered mandatory mission if possible
|
||||
if mandatory_prereq is not None:
|
||||
connect(world, names, mandatory_prereq.mission.mission_name, mission.mission.mission_name, unlock_rule)
|
||||
else:
|
||||
# If no mission is known to be mandatory, connect to all previous missions instead
|
||||
for prev_mission in mission.prev:
|
||||
connect(world, names, prev_mission.mission.mission_name, mission.mission.mission_name,
|
||||
lambda state, unlock_rule=unlock_rule: unlock_rule(state))
|
||||
# If there are no previous missions, connect to Menu instead
|
||||
connect(world, names, prev_mission.mission.mission_name, mission.mission.mission_name, unlock_rule)
|
||||
# As a last resort connect to Menu
|
||||
if len(mission.prev) == 0:
|
||||
connect(world, names, "Menu", mission.mission.mission_name,
|
||||
lambda state, unlock_rule=unlock_rule: unlock_rule(state))
|
||||
connect(world, names, "Menu", mission.mission.mission_name, unlock_rule)
|
||||
|
||||
|
||||
def connect(world: 'SC2World', used_names: Dict[str, int], source: str, target: str,
|
||||
|
||||
Reference in New Issue
Block a user