diff --git a/entrance_rando.py b/entrance_rando.py index ab329edf..5ed2cd76 100644 --- a/entrance_rando.py +++ b/entrance_rando.py @@ -50,13 +50,15 @@ class EntranceLookup: _random: random.Random _expands_graph_cache: dict[Entrance, bool] _coupled: bool + _usable_exits: set[Entrance] - def __init__(self, rng: random.Random, coupled: bool): + def __init__(self, rng: random.Random, coupled: bool, usable_exits: set[Entrance]): self.dead_ends = EntranceLookup.GroupLookup() self.others = EntranceLookup.GroupLookup() self._random = rng self._expands_graph_cache = {} self._coupled = coupled + self._usable_exits = usable_exits def _can_expand_graph(self, entrance: Entrance) -> bool: """ @@ -95,7 +97,8 @@ class EntranceLookup: # randomizable exits which are not reverse of the incoming entrance. # uncoupled mode is an exception because in this case going back in the door you just came in could # actually lead somewhere new - if not exit_.connected_region and (not self._coupled or exit_.name != entrance.name): + if (not exit_.connected_region and (not self._coupled or exit_.name != entrance.name) + and exit_ in self._usable_exits): self._expands_graph_cache[entrance] = True return True elif exit_.connected_region and exit_.connected_region not in visited: @@ -333,7 +336,6 @@ def randomize_entrances( start_time = time.perf_counter() er_state = ERPlacementState(world, coupled) - entrance_lookup = EntranceLookup(world.random, coupled) # similar to fill, skip validity checks on entrances if the game is beatable on minimal accessibility perform_validity_check = True @@ -349,6 +351,7 @@ def randomize_entrances( # used when membership checks are needed on the exit list, e.g. speculative sweep exits_set = set(exits) + entrance_lookup = EntranceLookup(world.random, coupled, exits_set) for entrance in er_targets: entrance_lookup.add(entrance) diff --git a/test/general/test_entrance_rando.py b/test/general/test_entrance_rando.py index 542b3b4b..56a059ec 100644 --- a/test/general/test_entrance_rando.py +++ b/test/general/test_entrance_rando.py @@ -65,8 +65,10 @@ class TestEntranceLookup(unittest.TestCase): """tests that get_targets shuffles targets between groups when requested""" multiworld = generate_test_multiworld() generate_disconnected_region_grid(multiworld, 5) + exits_set = set([ex for region in multiworld.get_regions(1) + for ex in region.exits if not ex.connected_region]) - lookup = EntranceLookup(multiworld.worlds[1].random, coupled=True) + lookup = EntranceLookup(multiworld.worlds[1].random, coupled=True, usable_exits=exits_set) er_targets = [entrance for region in multiworld.get_regions(1) for entrance in region.entrances if not entrance.parent_region] for entrance in er_targets: @@ -86,8 +88,10 @@ class TestEntranceLookup(unittest.TestCase): """tests that get_targets does not shuffle targets between groups when requested""" multiworld = generate_test_multiworld() generate_disconnected_region_grid(multiworld, 5) + exits_set = set([ex for region in multiworld.get_regions(1) + for ex in region.exits if not ex.connected_region]) - lookup = EntranceLookup(multiworld.worlds[1].random, coupled=True) + lookup = EntranceLookup(multiworld.worlds[1].random, coupled=True, usable_exits=exits_set) er_targets = [entrance for region in multiworld.get_regions(1) for entrance in region.entrances if not entrance.parent_region] for entrance in er_targets: @@ -99,6 +103,30 @@ class TestEntranceLookup(unittest.TestCase): group_order = [prev := group.randomization_group for group in retrieved_targets if prev != group.randomization_group] self.assertEqual([ERTestGroups.TOP, ERTestGroups.BOTTOM], group_order) + def test_selective_dead_ends(self): + """test that entrances that EntranceLookup has not been told to consider are ignored when finding dead-ends""" + multiworld = generate_test_multiworld() + generate_disconnected_region_grid(multiworld, 5) + exits_set = set([ex for region in multiworld.get_regions(1) + for ex in region.exits if not ex.connected_region + and ex.name != "region20_right" and ex.name != "region21_left"]) + + lookup = EntranceLookup(multiworld.worlds[1].random, coupled=True, usable_exits=exits_set) + er_targets = [entrance for region in multiworld.get_regions(1) + for entrance in region.entrances if not entrance.parent_region and + entrance.name != "region20_right" and entrance.name != "region21_left"] + for entrance in er_targets: + lookup.add(entrance) + # region 20 is the bottom left corner of the grid, and therefore only has a right entrance from region 21 + # and a top entrance from region 15; since we've told lookup to ignore the right entrance from region 21, + # the top entrance from region 15 should be considered a dead-end + dead_end_region = multiworld.get_region("region20", 1) + for dead_end in dead_end_region.entrances: + if dead_end.name == "region20_top": + break + # there should be only this one dead-end + self.assertTrue(dead_end in lookup.dead_ends) + self.assertEqual(len(lookup.dead_ends), 1) class TestBakeTargetGroupLookup(unittest.TestCase): def test_lookup_generation(self):