diff --git a/worlds/sc2wol/PoolFilter.py b/worlds/sc2wol/PoolFilter.py index e2d76d88..91fef13a 100644 --- a/worlds/sc2wol/PoolFilter.py +++ b/worlds/sc2wol/PoolFilter.py @@ -160,22 +160,30 @@ class ValidInventory: if item in cascade_keys: items_to_remove = self.cascade_removal_map[item] transient_items = [] + cascade_failure = False while len(items_to_remove) > 0: item_to_remove = items_to_remove.pop() + transient_items.append(item_to_remove) if item_to_remove not in inventory: - continue + if units_always_have_upgrades and item_to_remove in locked_items: + cascade_failure = True + break + else: + continue success = attempt_removal(item_to_remove) - if success: - transient_items.append(item_to_remove) - elif units_always_have_upgrades: - # Lock all associated items if any of them cannot be removed + if not success and units_always_have_upgrades: + cascade_failure = True transient_items += items_to_remove - for transient_item in transient_items: - if transient_item not in inventory and transient_item not in locked_items: - locked_items.append(transient_item) - if transient_item.classification in (ItemClassification.progression, ItemClassification.progression_skip_balancing): - self.logical_inventory.add(transient_item.name) break + # Lock all associated items if any of them cannot be removed on Units Always Have Upgrades + if cascade_failure: + for transient_item in transient_items: + if transient_item in inventory: + inventory.remove(transient_item) + if transient_item not in locked_items: + locked_items.append(transient_item) + if transient_item.classification in (ItemClassification.progression, ItemClassification.progression_skip_balancing): + self.logical_inventory.add(transient_item.name) else: attempt_removal(item) diff --git a/worlds/sc2wol/__init__.py b/worlds/sc2wol/__init__.py index dc77792d..2f2647cb 100644 --- a/worlds/sc2wol/__init__.py +++ b/worlds/sc2wol/__init__.py @@ -131,7 +131,7 @@ def get_excluded_items(self: SC2WoLWorld, world: MultiWorld, player: int) -> Set def assign_starter_items(world: MultiWorld, player: int, excluded_items: Set[str], locked_locations: List[str]) -> List[Item]: non_local_items = world.non_local_items[player].value if get_option_value(world, player, "early_unit"): - local_basic_unit = tuple(item for item in get_basic_units(world, player) if item not in non_local_items) + local_basic_unit = tuple(item for item in get_basic_units(world, player) if item not in non_local_items and item not in excluded_items) if not local_basic_unit: raise Exception("At least one basic unit must be local")