diff --git a/worlds/sc2/mission_order/entry_rules.py b/worlds/sc2/mission_order/entry_rules.py index cb3afb37..afa872de 100644 --- a/worlds/sc2/mission_order/entry_rules.py +++ b/worlds/sc2/mission_order/entry_rules.py @@ -10,6 +10,10 @@ from BaseClasses import CollectionState if TYPE_CHECKING: from .nodes import SC2MOGenMission +def always_true(state: CollectionState) -> bool: + """Helper method to avoid creating trivial lambdas""" + return True + class EntryRule(ABC): buffer_fulfilled: bool @@ -60,6 +64,11 @@ class EntryRule(ABC): """Used in the client to determine accessibility while playing and to populate tooltips.""" pass + @abstractmethod + def find_mandatory_mission(self) -> SC2MOGenMission | None: + """Should return any mission that is mandatory to fulfill the entry rule, or `None` if there is no such mission.""" + return None + @dataclass class RuleData(ABC): @@ -103,6 +112,11 @@ class BeatMissionsEntryRule(EntryRule): mission_ids, resolved_reqs ) + + def find_mandatory_mission(self) -> SC2MOGenMission | None: + if len(self.missions_to_beat) > 0: + return self.missions_to_beat[0] + return None @dataclass @@ -140,7 +154,7 @@ class CountMissionsEntryRule(EntryRule): def __init__(self, missions_to_count: List[SC2MOGenMission], target_amount: int, visual_reqs: List[Union[str, SC2MOGenMission]]): super().__init__() self.missions_to_count = missions_to_count - if target_amount == -1 or target_amount > len(missions_to_count): + if target_amount <= -1 or target_amount > len(missions_to_count): self.target_amount = len(missions_to_count) else: self.target_amount = target_amount @@ -155,7 +169,20 @@ class CountMissionsEntryRule(EntryRule): return max(mission_depth, self.target_amount - 1) # -1 because depth is zero-based but amount is one-based def to_lambda(self, player: int) -> Callable[[CollectionState], bool]: - return lambda state: self.target_amount <= sum(state.has(mission.beat_item(), player) for mission in self.missions_to_count) + if self.target_amount == 0: + return always_true + + beat_items = [mission.beat_item() for mission in self.missions_to_count] + def count_missions(state: CollectionState) -> bool: + count = 0 + for mission in range(len(self.missions_to_count)): + if state.has(beat_items[mission], player): + count += 1 + if count == self.target_amount: + return True + return False + + return count_missions def to_slot_data(self) -> RuleData: resolved_reqs: List[Union[str, int]] = [req if isinstance(req, str) else req.mission.id for req in self.visual_reqs] @@ -165,6 +192,11 @@ class CountMissionsEntryRule(EntryRule): self.target_amount, resolved_reqs ) + + def find_mandatory_mission(self) -> SC2MOGenMission | None: + if self.target_amount > 0 and self.target_amount == len(self.missions_to_count): + return self.missions_to_count[0] + return None @dataclass @@ -216,13 +248,21 @@ class SubRuleEntryRule(EntryRule): self.rule_id = rule_id self.rules_to_check = rules_to_check self.min_depth = -1 - if target_amount == -1 or target_amount > len(rules_to_check): + if target_amount <= -1 or target_amount > len(rules_to_check): self.target_amount = len(rules_to_check) else: self.target_amount = target_amount def _is_fulfilled(self, beaten_missions: Set[SC2MOGenMission], in_region_check: bool) -> bool: - return self.target_amount <= sum(rule.is_fulfilled(beaten_missions, in_region_check) for rule in self.rules_to_check) + if len(self.rules_to_check) == 0: + return True + count = 0 + for rule in self.rules_to_check: + if rule.is_fulfilled(beaten_missions, in_region_check): + count += 1 + if count == self.target_amount: + return True + return False def _get_depth(self, beaten_missions: Set[SC2MOGenMission]) -> int: if len(self.rules_to_check) == 0: @@ -235,7 +275,21 @@ class SubRuleEntryRule(EntryRule): def to_lambda(self, player: int) -> Callable[[CollectionState], bool]: sub_lambdas = [rule.to_lambda(player) for rule in self.rules_to_check] - return lambda state, sub_lambdas=sub_lambdas: self.target_amount <= sum(sub_lambda(state) for sub_lambda in sub_lambdas) + if self.target_amount == 0: + return always_true + if len(sub_lambdas) == 1: + return sub_lambdas[0] + + def count_rules(state: CollectionState) -> bool: + count = 0 + for sub_lambda in sub_lambdas: + if sub_lambda(state): + count += 1 + if count == self.target_amount: + return True + return False + + return count_rules def to_slot_data(self) -> SubRuleRuleData: sub_rules = [rule.to_slot_data() for rule in self.rules_to_check] @@ -244,6 +298,14 @@ class SubRuleEntryRule(EntryRule): sub_rules, self.target_amount ) + + def find_mandatory_mission(self) -> SC2MOGenMission | None: + if self.target_amount > 0 and self.target_amount == len(self.rules_to_check): + for sub_rule in self.rules_to_check: + mandatory_mission = sub_rule.find_mandatory_mission() + if mandatory_mission is not None: + return mandatory_mission + return None @dataclass @@ -362,6 +424,9 @@ class ItemEntryRule(EntryRule): item_ids, visual_reqs ) + + def find_mandatory_mission(self) -> SC2MOGenMission | None: + return None @dataclass diff --git a/worlds/sc2/mission_order/generation.py b/worlds/sc2/mission_order/generation.py index 5582d7c3..928c0a45 100644 --- a/worlds/sc2/mission_order/generation.py +++ b/worlds/sc2/mission_order/generation.py @@ -491,23 +491,60 @@ def make_connections(mission_order: SC2MOGenMissionOrder, world: 'SC2World'): for layout in campaign.layouts: for mission in layout.missions: if not mission.option_empty: + mission_uses_rule = mission.entry_rule.target_amount > 0 mission_rule = mission.entry_rule.to_lambda(player) + mandatory_prereq = mission.entry_rule.find_mandatory_mission() # Only layout entrances need to consider campaign & layout prerequisites if mission.option_entrance: - campaign_rule = mission.parent().parent().entry_rule.to_lambda(player) - layout_rule = mission.parent().entry_rule.to_lambda(player) - unlock_rule = lambda state, campaign_rule=campaign_rule, layout_rule=layout_rule, mission_rule=mission_rule: \ - campaign_rule(state) and layout_rule(state) and mission_rule(state) - else: + campaign_uses_rule = campaign.entry_rule.target_amount > 0 + campaign_rule = campaign.entry_rule.to_lambda(player) + layout_uses_rule = layout.entry_rule.target_amount > 0 + layout_rule = layout.entry_rule.to_lambda(player) + + # Any mandatory prerequisite mission is good enough + mandatory_prereq = campaign.entry_rule.find_mandatory_mission() if mandatory_prereq is None else mandatory_prereq + mandatory_prereq = layout.entry_rule.find_mandatory_mission() if mandatory_prereq is None else mandatory_prereq + + # Avoid calling obviously unused lambdas + if campaign_uses_rule: + if layout_uses_rule: + if mission_uses_rule: + unlock_rule = lambda state, campaign_rule=campaign_rule, layout_rule=layout_rule, mission_rule=mission_rule: \ + campaign_rule(state) and layout_rule(state) and mission_rule(state) + else: + unlock_rule = lambda state, campaign_rule=campaign_rule, layout_rule=layout_rule: \ + campaign_rule(state) and layout_rule(state) + else: + if mission_uses_rule: + unlock_rule = lambda state, campaign_rule=campaign_rule, mission_rule=mission_rule: \ + campaign_rule(state) and mission_rule(state) + else: + unlock_rule = campaign_rule + elif layout_uses_rule: + if mission_uses_rule: + unlock_rule = lambda state, layout_rule=layout_rule, mission_rule=mission_rule: \ + layout_rule(state) and mission_rule(state) + else: + unlock_rule = layout_rule + elif mission_uses_rule: + unlock_rule = mission_rule + else: + unlock_rule = None + elif mission_uses_rule: unlock_rule = mission_rule - # Individually connect to previous missions - for prev_mission in mission.prev: - connect(world, names, prev_mission.mission.mission_name, mission.mission.mission_name, - lambda state, unlock_rule=unlock_rule: unlock_rule(state)) - # If there are no previous missions, connect to Menu instead - if len(mission.prev) == 0: - connect(world, names, "Menu", mission.mission.mission_name, - lambda state, unlock_rule=unlock_rule: unlock_rule(state)) + else: + unlock_rule = None + + # Connect to a discovered mandatory mission if possible + if mandatory_prereq is not None: + connect(world, names, mandatory_prereq.mission.mission_name, mission.mission.mission_name, unlock_rule) + else: + # If no mission is known to be mandatory, connect to all previous missions instead + for prev_mission in mission.prev: + connect(world, names, prev_mission.mission.mission_name, mission.mission.mission_name, unlock_rule) + # As a last resort connect to Menu + if len(mission.prev) == 0: + connect(world, names, "Menu", mission.mission.mission_name, unlock_rule) def connect(world: 'SC2World', used_names: Dict[str, int], source: str, target: str,