SC2: Region access rule speedups (#5426)
This commit is contained in:
@@ -10,6 +10,10 @@ from BaseClasses import CollectionState
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .nodes import SC2MOGenMission
|
from .nodes import SC2MOGenMission
|
||||||
|
|
||||||
|
def always_true(state: CollectionState) -> bool:
|
||||||
|
"""Helper method to avoid creating trivial lambdas"""
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
class EntryRule(ABC):
|
class EntryRule(ABC):
|
||||||
buffer_fulfilled: bool
|
buffer_fulfilled: bool
|
||||||
@@ -60,6 +64,11 @@ class EntryRule(ABC):
|
|||||||
"""Used in the client to determine accessibility while playing and to populate tooltips."""
|
"""Used in the client to determine accessibility while playing and to populate tooltips."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def find_mandatory_mission(self) -> SC2MOGenMission | None:
|
||||||
|
"""Should return any mission that is mandatory to fulfill the entry rule, or `None` if there is no such mission."""
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class RuleData(ABC):
|
class RuleData(ABC):
|
||||||
@@ -104,6 +113,11 @@ class BeatMissionsEntryRule(EntryRule):
|
|||||||
resolved_reqs
|
resolved_reqs
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def find_mandatory_mission(self) -> SC2MOGenMission | None:
|
||||||
|
if len(self.missions_to_beat) > 0:
|
||||||
|
return self.missions_to_beat[0]
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BeatMissionsRuleData(RuleData):
|
class BeatMissionsRuleData(RuleData):
|
||||||
@@ -140,7 +154,7 @@ class CountMissionsEntryRule(EntryRule):
|
|||||||
def __init__(self, missions_to_count: List[SC2MOGenMission], target_amount: int, visual_reqs: List[Union[str, SC2MOGenMission]]):
|
def __init__(self, missions_to_count: List[SC2MOGenMission], target_amount: int, visual_reqs: List[Union[str, SC2MOGenMission]]):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.missions_to_count = missions_to_count
|
self.missions_to_count = missions_to_count
|
||||||
if target_amount == -1 or target_amount > len(missions_to_count):
|
if target_amount <= -1 or target_amount > len(missions_to_count):
|
||||||
self.target_amount = len(missions_to_count)
|
self.target_amount = len(missions_to_count)
|
||||||
else:
|
else:
|
||||||
self.target_amount = target_amount
|
self.target_amount = target_amount
|
||||||
@@ -155,7 +169,20 @@ class CountMissionsEntryRule(EntryRule):
|
|||||||
return max(mission_depth, self.target_amount - 1) # -1 because depth is zero-based but amount is one-based
|
return max(mission_depth, self.target_amount - 1) # -1 because depth is zero-based but amount is one-based
|
||||||
|
|
||||||
def to_lambda(self, player: int) -> Callable[[CollectionState], bool]:
|
def to_lambda(self, player: int) -> Callable[[CollectionState], bool]:
|
||||||
return lambda state: self.target_amount <= sum(state.has(mission.beat_item(), player) for mission in self.missions_to_count)
|
if self.target_amount == 0:
|
||||||
|
return always_true
|
||||||
|
|
||||||
|
beat_items = [mission.beat_item() for mission in self.missions_to_count]
|
||||||
|
def count_missions(state: CollectionState) -> bool:
|
||||||
|
count = 0
|
||||||
|
for mission in range(len(self.missions_to_count)):
|
||||||
|
if state.has(beat_items[mission], player):
|
||||||
|
count += 1
|
||||||
|
if count == self.target_amount:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
return count_missions
|
||||||
|
|
||||||
def to_slot_data(self) -> RuleData:
|
def to_slot_data(self) -> RuleData:
|
||||||
resolved_reqs: List[Union[str, int]] = [req if isinstance(req, str) else req.mission.id for req in self.visual_reqs]
|
resolved_reqs: List[Union[str, int]] = [req if isinstance(req, str) else req.mission.id for req in self.visual_reqs]
|
||||||
@@ -166,6 +193,11 @@ class CountMissionsEntryRule(EntryRule):
|
|||||||
resolved_reqs
|
resolved_reqs
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def find_mandatory_mission(self) -> SC2MOGenMission | None:
|
||||||
|
if self.target_amount > 0 and self.target_amount == len(self.missions_to_count):
|
||||||
|
return self.missions_to_count[0]
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class CountMissionsRuleData(RuleData):
|
class CountMissionsRuleData(RuleData):
|
||||||
@@ -216,13 +248,21 @@ class SubRuleEntryRule(EntryRule):
|
|||||||
self.rule_id = rule_id
|
self.rule_id = rule_id
|
||||||
self.rules_to_check = rules_to_check
|
self.rules_to_check = rules_to_check
|
||||||
self.min_depth = -1
|
self.min_depth = -1
|
||||||
if target_amount == -1 or target_amount > len(rules_to_check):
|
if target_amount <= -1 or target_amount > len(rules_to_check):
|
||||||
self.target_amount = len(rules_to_check)
|
self.target_amount = len(rules_to_check)
|
||||||
else:
|
else:
|
||||||
self.target_amount = target_amount
|
self.target_amount = target_amount
|
||||||
|
|
||||||
def _is_fulfilled(self, beaten_missions: Set[SC2MOGenMission], in_region_check: bool) -> bool:
|
def _is_fulfilled(self, beaten_missions: Set[SC2MOGenMission], in_region_check: bool) -> bool:
|
||||||
return self.target_amount <= sum(rule.is_fulfilled(beaten_missions, in_region_check) for rule in self.rules_to_check)
|
if len(self.rules_to_check) == 0:
|
||||||
|
return True
|
||||||
|
count = 0
|
||||||
|
for rule in self.rules_to_check:
|
||||||
|
if rule.is_fulfilled(beaten_missions, in_region_check):
|
||||||
|
count += 1
|
||||||
|
if count == self.target_amount:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
def _get_depth(self, beaten_missions: Set[SC2MOGenMission]) -> int:
|
def _get_depth(self, beaten_missions: Set[SC2MOGenMission]) -> int:
|
||||||
if len(self.rules_to_check) == 0:
|
if len(self.rules_to_check) == 0:
|
||||||
@@ -235,7 +275,21 @@ class SubRuleEntryRule(EntryRule):
|
|||||||
|
|
||||||
def to_lambda(self, player: int) -> Callable[[CollectionState], bool]:
|
def to_lambda(self, player: int) -> Callable[[CollectionState], bool]:
|
||||||
sub_lambdas = [rule.to_lambda(player) for rule in self.rules_to_check]
|
sub_lambdas = [rule.to_lambda(player) for rule in self.rules_to_check]
|
||||||
return lambda state, sub_lambdas=sub_lambdas: self.target_amount <= sum(sub_lambda(state) for sub_lambda in sub_lambdas)
|
if self.target_amount == 0:
|
||||||
|
return always_true
|
||||||
|
if len(sub_lambdas) == 1:
|
||||||
|
return sub_lambdas[0]
|
||||||
|
|
||||||
|
def count_rules(state: CollectionState) -> bool:
|
||||||
|
count = 0
|
||||||
|
for sub_lambda in sub_lambdas:
|
||||||
|
if sub_lambda(state):
|
||||||
|
count += 1
|
||||||
|
if count == self.target_amount:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
return count_rules
|
||||||
|
|
||||||
def to_slot_data(self) -> SubRuleRuleData:
|
def to_slot_data(self) -> SubRuleRuleData:
|
||||||
sub_rules = [rule.to_slot_data() for rule in self.rules_to_check]
|
sub_rules = [rule.to_slot_data() for rule in self.rules_to_check]
|
||||||
@@ -245,6 +299,14 @@ class SubRuleEntryRule(EntryRule):
|
|||||||
self.target_amount
|
self.target_amount
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def find_mandatory_mission(self) -> SC2MOGenMission | None:
|
||||||
|
if self.target_amount > 0 and self.target_amount == len(self.rules_to_check):
|
||||||
|
for sub_rule in self.rules_to_check:
|
||||||
|
mandatory_mission = sub_rule.find_mandatory_mission()
|
||||||
|
if mandatory_mission is not None:
|
||||||
|
return mandatory_mission
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class SubRuleRuleData(RuleData):
|
class SubRuleRuleData(RuleData):
|
||||||
@@ -363,6 +425,9 @@ class ItemEntryRule(EntryRule):
|
|||||||
visual_reqs
|
visual_reqs
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def find_mandatory_mission(self) -> SC2MOGenMission | None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ItemRuleData(RuleData):
|
class ItemRuleData(RuleData):
|
||||||
|
|||||||
@@ -491,23 +491,60 @@ def make_connections(mission_order: SC2MOGenMissionOrder, world: 'SC2World'):
|
|||||||
for layout in campaign.layouts:
|
for layout in campaign.layouts:
|
||||||
for mission in layout.missions:
|
for mission in layout.missions:
|
||||||
if not mission.option_empty:
|
if not mission.option_empty:
|
||||||
|
mission_uses_rule = mission.entry_rule.target_amount > 0
|
||||||
mission_rule = mission.entry_rule.to_lambda(player)
|
mission_rule = mission.entry_rule.to_lambda(player)
|
||||||
|
mandatory_prereq = mission.entry_rule.find_mandatory_mission()
|
||||||
# Only layout entrances need to consider campaign & layout prerequisites
|
# Only layout entrances need to consider campaign & layout prerequisites
|
||||||
if mission.option_entrance:
|
if mission.option_entrance:
|
||||||
campaign_rule = mission.parent().parent().entry_rule.to_lambda(player)
|
campaign_uses_rule = campaign.entry_rule.target_amount > 0
|
||||||
layout_rule = mission.parent().entry_rule.to_lambda(player)
|
campaign_rule = campaign.entry_rule.to_lambda(player)
|
||||||
unlock_rule = lambda state, campaign_rule=campaign_rule, layout_rule=layout_rule, mission_rule=mission_rule: \
|
layout_uses_rule = layout.entry_rule.target_amount > 0
|
||||||
campaign_rule(state) and layout_rule(state) and mission_rule(state)
|
layout_rule = layout.entry_rule.to_lambda(player)
|
||||||
else:
|
|
||||||
|
# Any mandatory prerequisite mission is good enough
|
||||||
|
mandatory_prereq = campaign.entry_rule.find_mandatory_mission() if mandatory_prereq is None else mandatory_prereq
|
||||||
|
mandatory_prereq = layout.entry_rule.find_mandatory_mission() if mandatory_prereq is None else mandatory_prereq
|
||||||
|
|
||||||
|
# Avoid calling obviously unused lambdas
|
||||||
|
if campaign_uses_rule:
|
||||||
|
if layout_uses_rule:
|
||||||
|
if mission_uses_rule:
|
||||||
|
unlock_rule = lambda state, campaign_rule=campaign_rule, layout_rule=layout_rule, mission_rule=mission_rule: \
|
||||||
|
campaign_rule(state) and layout_rule(state) and mission_rule(state)
|
||||||
|
else:
|
||||||
|
unlock_rule = lambda state, campaign_rule=campaign_rule, layout_rule=layout_rule: \
|
||||||
|
campaign_rule(state) and layout_rule(state)
|
||||||
|
else:
|
||||||
|
if mission_uses_rule:
|
||||||
|
unlock_rule = lambda state, campaign_rule=campaign_rule, mission_rule=mission_rule: \
|
||||||
|
campaign_rule(state) and mission_rule(state)
|
||||||
|
else:
|
||||||
|
unlock_rule = campaign_rule
|
||||||
|
elif layout_uses_rule:
|
||||||
|
if mission_uses_rule:
|
||||||
|
unlock_rule = lambda state, layout_rule=layout_rule, mission_rule=mission_rule: \
|
||||||
|
layout_rule(state) and mission_rule(state)
|
||||||
|
else:
|
||||||
|
unlock_rule = layout_rule
|
||||||
|
elif mission_uses_rule:
|
||||||
|
unlock_rule = mission_rule
|
||||||
|
else:
|
||||||
|
unlock_rule = None
|
||||||
|
elif mission_uses_rule:
|
||||||
unlock_rule = mission_rule
|
unlock_rule = mission_rule
|
||||||
# Individually connect to previous missions
|
else:
|
||||||
for prev_mission in mission.prev:
|
unlock_rule = None
|
||||||
connect(world, names, prev_mission.mission.mission_name, mission.mission.mission_name,
|
|
||||||
lambda state, unlock_rule=unlock_rule: unlock_rule(state))
|
# Connect to a discovered mandatory mission if possible
|
||||||
# If there are no previous missions, connect to Menu instead
|
if mandatory_prereq is not None:
|
||||||
if len(mission.prev) == 0:
|
connect(world, names, mandatory_prereq.mission.mission_name, mission.mission.mission_name, unlock_rule)
|
||||||
connect(world, names, "Menu", mission.mission.mission_name,
|
else:
|
||||||
lambda state, unlock_rule=unlock_rule: unlock_rule(state))
|
# If no mission is known to be mandatory, connect to all previous missions instead
|
||||||
|
for prev_mission in mission.prev:
|
||||||
|
connect(world, names, prev_mission.mission.mission_name, mission.mission.mission_name, unlock_rule)
|
||||||
|
# As a last resort connect to Menu
|
||||||
|
if len(mission.prev) == 0:
|
||||||
|
connect(world, names, "Menu", mission.mission.mission_name, unlock_rule)
|
||||||
|
|
||||||
|
|
||||||
def connect(world: 'SC2World', used_names: Dict[str, int], source: str, target: str,
|
def connect(world: 'SC2World', used_names: Dict[str, int], source: str, target: str,
|
||||||
|
|||||||
Reference in New Issue
Block a user