GER: Move EntranceLookup onto ERPlacementState. Improve usefulness of on_connect. (#4904)

Co-authored-by: Exempt-Medic <60412657+Exempt-Medic@users.noreply.github.com>
This commit is contained in:
BadMagic100
2025-07-25 11:55:22 -07:00
committed by GitHub
parent e5815ae5a2
commit 88e8e2408b
2 changed files with 122 additions and 38 deletions

View File

@@ -52,13 +52,15 @@ class EntranceLookup:
_coupled: bool _coupled: bool
_usable_exits: set[Entrance] _usable_exits: set[Entrance]
def __init__(self, rng: random.Random, coupled: bool, usable_exits: set[Entrance]): def __init__(self, rng: random.Random, coupled: bool, usable_exits: set[Entrance], targets: Iterable[Entrance]):
self.dead_ends = EntranceLookup.GroupLookup() self.dead_ends = EntranceLookup.GroupLookup()
self.others = EntranceLookup.GroupLookup() self.others = EntranceLookup.GroupLookup()
self._random = rng self._random = rng
self._expands_graph_cache = {} self._expands_graph_cache = {}
self._coupled = coupled self._coupled = coupled
self._usable_exits = usable_exits self._usable_exits = usable_exits
for target in targets:
self.add(target)
def _can_expand_graph(self, entrance: Entrance) -> bool: def _can_expand_graph(self, entrance: Entrance) -> bool:
""" """
@@ -121,7 +123,14 @@ class EntranceLookup:
dead_end: bool, dead_end: bool,
preserve_group_order: bool preserve_group_order: bool
) -> Iterable[Entrance]: ) -> Iterable[Entrance]:
"""
Gets available targets for the requested groups
:param groups: The groups to find targets for
:param dead_end: Whether to find dead ends. If false, finds non-dead-ends
:param preserve_group_order: Whether to preserve the group order in the returned iterable. If true, a sequence
like AAABBB is guaranteed. If false, groups can be interleaved, e.g. BAABAB.
"""
lookup = self.dead_ends if dead_end else self.others lookup = self.dead_ends if dead_end else self.others
if preserve_group_order: if preserve_group_order:
for group in groups: for group in groups:
@@ -132,6 +141,27 @@ class EntranceLookup:
self._random.shuffle(ret) self._random.shuffle(ret)
return ret return ret
def find_target(self, name: str, group: int | None = None, dead_end: bool | None = None) -> Entrance | None:
"""
Finds a specific target in the lookup, if it is present.
:param name: The name of the target
:param group: The target's group. Providing this will make the lookup faster, but can be omitted if it is not
known ahead of time for some reason.
:param dead_end: Whether the target is a dead end. Providing this will make the lookup faster, but can be
omitted if this is not known ahead of time (much more likely)
"""
if dead_end is None:
return (found
if (found := self.find_target(name, group, True))
else self.find_target(name, group, False))
lookup = self.dead_ends if dead_end else self.others
targets_to_check = lookup if group is None else lookup[group]
for target in targets_to_check:
if target.name == name:
return target
return None
def __len__(self): def __len__(self):
return len(self.dead_ends) + len(self.others) return len(self.dead_ends) + len(self.others)
@@ -146,15 +176,18 @@ class ERPlacementState:
"""The world which is having its entrances randomized""" """The world which is having its entrances randomized"""
collection_state: CollectionState collection_state: CollectionState
"""The CollectionState backing the entrance randomization logic""" """The CollectionState backing the entrance randomization logic"""
entrance_lookup: EntranceLookup
"""A lookup table of all unconnected ER targets"""
coupled: bool coupled: bool
"""Whether entrance randomization is operating in coupled mode""" """Whether entrance randomization is operating in coupled mode"""
def __init__(self, world: World, coupled: bool): def __init__(self, world: World, entrance_lookup: EntranceLookup, coupled: bool):
self.placements = [] self.placements = []
self.pairings = [] self.pairings = []
self.world = world self.world = world
self.coupled = coupled self.coupled = coupled
self.collection_state = world.multiworld.get_all_state(False, True) self.collection_state = world.multiworld.get_all_state(False, True)
self.entrance_lookup = entrance_lookup
@property @property
def placed_regions(self) -> set[Region]: def placed_regions(self) -> set[Region]:
@@ -182,6 +215,7 @@ class ERPlacementState:
self.collection_state.stale[self.world.player] = True self.collection_state.stale[self.world.player] = True
self.placements.append(source_exit) self.placements.append(source_exit)
self.pairings.append((source_exit.name, target_entrance.name)) self.pairings.append((source_exit.name, target_entrance.name))
self.entrance_lookup.remove(target_entrance)
def test_speculative_connection(self, source_exit: Entrance, target_entrance: Entrance, def test_speculative_connection(self, source_exit: Entrance, target_entrance: Entrance,
usable_exits: set[Entrance]) -> bool: usable_exits: set[Entrance]) -> bool:
@@ -311,7 +345,7 @@ def randomize_entrances(
preserve_group_order: bool = False, preserve_group_order: bool = False,
er_targets: list[Entrance] | None = None, er_targets: list[Entrance] | None = None,
exits: list[Entrance] | None = None, exits: list[Entrance] | None = None,
on_connect: Callable[[ERPlacementState, list[Entrance]], None] | None = None on_connect: Callable[[ERPlacementState, list[Entrance], list[Entrance]], bool | None] | None = None
) -> ERPlacementState: ) -> ERPlacementState:
""" """
Randomizes Entrances for a single world in the multiworld. Randomizes Entrances for a single world in the multiworld.
@@ -328,14 +362,18 @@ def randomize_entrances(
:param exits: The list of exits (Entrance objects with no target region) to use for randomization. :param exits: The list of exits (Entrance objects with no target region) to use for randomization.
Remember to be deterministic! If not provided, automatically discovers all valid exits in your world. Remember to be deterministic! If not provided, automatically discovers all valid exits in your world.
:param on_connect: A callback function which allows specifying side effects after a placement is completed :param on_connect: A callback function which allows specifying side effects after a placement is completed
successfully and the underlying collection state has been updated. successfully and the underlying collection state has been updated. The arguments are
1. The ER state
2. The exits placed in this placement pass
3. The entrances they were connected to.
If you use on_connect to make additional placements, you are expected to return True to inform
GER that an additional sweep is needed.
""" """
if not world.explicit_indirect_conditions: if not world.explicit_indirect_conditions:
raise EntranceRandomizationError("Entrance randomization requires explicit indirect conditions in order " raise EntranceRandomizationError("Entrance randomization requires explicit indirect conditions in order "
+ "to correctly analyze whether dead end regions can be required in logic.") + "to correctly analyze whether dead end regions can be required in logic.")
start_time = time.perf_counter() start_time = time.perf_counter()
er_state = ERPlacementState(world, coupled)
# similar to fill, skip validity checks on entrances if the game is beatable on minimal accessibility # similar to fill, skip validity checks on entrances if the game is beatable on minimal accessibility
perform_validity_check = True perform_validity_check = True
@@ -351,23 +389,25 @@ def randomize_entrances(
# used when membership checks are needed on the exit list, e.g. speculative sweep # used when membership checks are needed on the exit list, e.g. speculative sweep
exits_set = set(exits) exits_set = set(exits)
entrance_lookup = EntranceLookup(world.random, coupled, exits_set)
for entrance in er_targets:
entrance_lookup.add(entrance)
er_state = ERPlacementState(
world,
EntranceLookup(world.random, coupled, exits_set, er_targets),
coupled
)
# place the menu region and connected start region(s) # place the menu region and connected start region(s)
er_state.collection_state.update_reachable_regions(world.player) er_state.collection_state.update_reachable_regions(world.player)
def do_placement(source_exit: Entrance, target_entrance: Entrance) -> None: def do_placement(source_exit: Entrance, target_entrance: Entrance) -> None:
placed_exits, removed_entrances = er_state.connect(source_exit, target_entrance) placed_exits, paired_entrances = er_state.connect(source_exit, target_entrance)
# remove the placed targets from consideration
for entrance in removed_entrances:
entrance_lookup.remove(entrance)
# propagate new connections # propagate new connections
er_state.collection_state.update_reachable_regions(world.player) er_state.collection_state.update_reachable_regions(world.player)
er_state.collection_state.sweep_for_advancements() er_state.collection_state.sweep_for_advancements()
if on_connect: if on_connect:
on_connect(er_state, placed_exits) change = on_connect(er_state, placed_exits, paired_entrances)
if change:
er_state.collection_state.update_reachable_regions(world.player)
er_state.collection_state.sweep_for_advancements()
def needs_speculative_sweep(dead_end: bool, require_new_exits: bool, placeable_exits: list[Entrance]) -> bool: def needs_speculative_sweep(dead_end: bool, require_new_exits: bool, placeable_exits: list[Entrance]) -> bool:
# speculative sweep is expensive. We currently only do it as a last resort, if we might cap off the graph # speculative sweep is expensive. We currently only do it as a last resort, if we might cap off the graph
@@ -388,12 +428,12 @@ def randomize_entrances(
# check to see if we are proposing the last placement # check to see if we are proposing the last placement
if not coupled: if not coupled:
# in uncoupled, this check is easy as there will only be one target. # in uncoupled, this check is easy as there will only be one target.
is_last_placement = len(entrance_lookup) == 1 is_last_placement = len(er_state.entrance_lookup) == 1
else: else:
# a bit harder, there may be 1 or 2 targets depending on if the exit to place is one way or two way. # a bit harder, there may be 1 or 2 targets depending on if the exit to place is one way or two way.
# if it is two way, we can safely assume that one of the targets is the logical pair of the exit. # if it is two way, we can safely assume that one of the targets is the logical pair of the exit.
desired_target_count = 2 if placeable_exits[0].randomization_type == EntranceType.TWO_WAY else 1 desired_target_count = 2 if placeable_exits[0].randomization_type == EntranceType.TWO_WAY else 1
is_last_placement = len(entrance_lookup) == desired_target_count is_last_placement = len(er_state.entrance_lookup) == desired_target_count
# if it's not the last placement, we need a sweep # if it's not the last placement, we need a sweep
return not is_last_placement return not is_last_placement
@@ -402,7 +442,7 @@ def randomize_entrances(
placeable_exits = er_state.find_placeable_exits(perform_validity_check, exits) placeable_exits = er_state.find_placeable_exits(perform_validity_check, exits)
for source_exit in placeable_exits: for source_exit in placeable_exits:
target_groups = target_group_lookup[source_exit.randomization_group] target_groups = target_group_lookup[source_exit.randomization_group]
for target_entrance in entrance_lookup.get_targets(target_groups, dead_end, preserve_group_order): for target_entrance in er_state.entrance_lookup.get_targets(target_groups, dead_end, preserve_group_order):
# when requiring new exits, ideally we would like to make it so that every placement increases # when requiring new exits, ideally we would like to make it so that every placement increases
# (or keeps the same number of) reachable exits. The goal is to continue to expand the search space # (or keeps the same number of) reachable exits. The goal is to continue to expand the search space
# so that we do not crash. In the interest of performance and bias reduction, generally, just checking # so that we do not crash. In the interest of performance and bias reduction, generally, just checking
@@ -420,7 +460,7 @@ def randomize_entrances(
else: else:
# no source exits had any valid target so this stage is deadlocked. retries may be implemented if early # no source exits had any valid target so this stage is deadlocked. retries may be implemented if early
# deadlocking is a frequent issue. # deadlocking is a frequent issue.
lookup = entrance_lookup.dead_ends if dead_end else entrance_lookup.others lookup = er_state.entrance_lookup.dead_ends if dead_end else er_state.entrance_lookup.others
# if we're in a stage where we're trying to get to new regions, we could also enter this # if we're in a stage where we're trying to get to new regions, we could also enter this
# branch in a success state (when all regions of the preferred type have been placed, but there are still # branch in a success state (when all regions of the preferred type have been placed, but there are still
@@ -466,21 +506,21 @@ def randomize_entrances(
f"All unplaced exits: {unplaced_exits}") f"All unplaced exits: {unplaced_exits}")
# stage 1 - try to place all the non-dead-end entrances # stage 1 - try to place all the non-dead-end entrances
while entrance_lookup.others: while er_state.entrance_lookup.others:
if not find_pairing(dead_end=False, require_new_exits=True): if not find_pairing(dead_end=False, require_new_exits=True):
break break
# stage 2 - try to place all the dead-end entrances # stage 2 - try to place all the dead-end entrances
while entrance_lookup.dead_ends: while er_state.entrance_lookup.dead_ends:
if not find_pairing(dead_end=True, require_new_exits=True): if not find_pairing(dead_end=True, require_new_exits=True):
break break
# stage 3 - all the regions should be placed at this point. We now need to connect dangling edges # stage 3 - all the regions should be placed at this point. We now need to connect dangling edges
# stage 3a - get the rest of the dead ends (e.g. second entrances into already-visited regions) # stage 3a - get the rest of the dead ends (e.g. second entrances into already-visited regions)
# doing this before the non-dead-ends is important to ensure there are enough connections to # doing this before the non-dead-ends is important to ensure there are enough connections to
# go around # go around
while entrance_lookup.dead_ends: while er_state.entrance_lookup.dead_ends:
find_pairing(dead_end=True, require_new_exits=False) find_pairing(dead_end=True, require_new_exits=False)
# stage 3b - tie all the other loose ends connecting visited regions to each other # stage 3b - tie all the other loose ends connecting visited regions to each other
while entrance_lookup.others: while er_state.entrance_lookup.others:
find_pairing(dead_end=False, require_new_exits=False) find_pairing(dead_end=False, require_new_exits=False)
running_time = time.perf_counter() - start_time running_time = time.perf_counter() - start_time

View File

@@ -69,11 +69,9 @@ class TestEntranceLookup(unittest.TestCase):
exits_set = set([ex for region in multiworld.get_regions(1) exits_set = set([ex for region in multiworld.get_regions(1)
for ex in region.exits if not ex.connected_region]) for ex in region.exits if not ex.connected_region])
lookup = EntranceLookup(multiworld.worlds[1].random, coupled=True, usable_exits=exits_set)
er_targets = [entrance for region in multiworld.get_regions(1) er_targets = [entrance for region in multiworld.get_regions(1)
for entrance in region.entrances if not entrance.parent_region] for entrance in region.entrances if not entrance.parent_region]
for entrance in er_targets: lookup = EntranceLookup(multiworld.worlds[1].random, coupled=True, usable_exits=exits_set, targets=er_targets)
lookup.add(entrance)
retrieved_targets = lookup.get_targets([ERTestGroups.TOP, ERTestGroups.BOTTOM], retrieved_targets = lookup.get_targets([ERTestGroups.TOP, ERTestGroups.BOTTOM],
False, False) False, False)
@@ -92,11 +90,9 @@ class TestEntranceLookup(unittest.TestCase):
exits_set = set([ex for region in multiworld.get_regions(1) exits_set = set([ex for region in multiworld.get_regions(1)
for ex in region.exits if not ex.connected_region]) for ex in region.exits if not ex.connected_region])
lookup = EntranceLookup(multiworld.worlds[1].random, coupled=True, usable_exits=exits_set)
er_targets = [entrance for region in multiworld.get_regions(1) er_targets = [entrance for region in multiworld.get_regions(1)
for entrance in region.entrances if not entrance.parent_region] for entrance in region.entrances if not entrance.parent_region]
for entrance in er_targets: lookup = EntranceLookup(multiworld.worlds[1].random, coupled=True, usable_exits=exits_set, targets=er_targets)
lookup.add(entrance)
retrieved_targets = lookup.get_targets([ERTestGroups.TOP, ERTestGroups.BOTTOM], retrieved_targets = lookup.get_targets([ERTestGroups.TOP, ERTestGroups.BOTTOM],
False, True) False, True)
@@ -112,12 +108,10 @@ class TestEntranceLookup(unittest.TestCase):
for ex in region.exits if not ex.connected_region for ex in region.exits if not ex.connected_region
and ex.name != "region20_right" and ex.name != "region21_left"]) and ex.name != "region20_right" and ex.name != "region21_left"])
lookup = EntranceLookup(multiworld.worlds[1].random, coupled=True, usable_exits=exits_set)
er_targets = [entrance for region in multiworld.get_regions(1) er_targets = [entrance for region in multiworld.get_regions(1)
for entrance in region.entrances if not entrance.parent_region and for entrance in region.entrances if not entrance.parent_region and
entrance.name != "region20_right" and entrance.name != "region21_left"] entrance.name != "region20_right" and entrance.name != "region21_left"]
for entrance in er_targets: lookup = EntranceLookup(multiworld.worlds[1].random, coupled=True, usable_exits=exits_set, targets=er_targets)
lookup.add(entrance)
# region 20 is the bottom left corner of the grid, and therefore only has a right entrance from region 21 # region 20 is the bottom left corner of the grid, and therefore only has a right entrance from region 21
# and a top entrance from region 15; since we've told lookup to ignore the right entrance from region 21, # and a top entrance from region 15; since we've told lookup to ignore the right entrance from region 21,
# the top entrance from region 15 should be considered a dead-end # the top entrance from region 15 should be considered a dead-end
@@ -129,6 +123,56 @@ class TestEntranceLookup(unittest.TestCase):
self.assertTrue(dead_end in lookup.dead_ends) self.assertTrue(dead_end in lookup.dead_ends)
self.assertEqual(len(lookup.dead_ends), 1) self.assertEqual(len(lookup.dead_ends), 1)
def test_find_target_by_name(self):
"""Tests that find_target can find the correct target by name only"""
multiworld = generate_test_multiworld()
generate_disconnected_region_grid(multiworld, 5)
exits_set = set([ex for region in multiworld.get_regions(1)
for ex in region.exits if not ex.connected_region])
er_targets = [entrance for region in multiworld.get_regions(1)
for entrance in region.entrances if not entrance.parent_region]
lookup = EntranceLookup(multiworld.worlds[1].random, coupled=True, usable_exits=exits_set, targets=er_targets)
target = lookup.find_target("region0_right")
self.assertEqual(target.name, "region0_right")
self.assertEqual(target.randomization_group, ERTestGroups.RIGHT)
self.assertIsNone(lookup.find_target("nonexistant"))
def test_find_target_by_name_and_group(self):
"""Tests that find_target can find the correct target by name and group"""
multiworld = generate_test_multiworld()
generate_disconnected_region_grid(multiworld, 5)
exits_set = set([ex for region in multiworld.get_regions(1)
for ex in region.exits if not ex.connected_region])
er_targets = [entrance for region in multiworld.get_regions(1)
for entrance in region.entrances if not entrance.parent_region]
lookup = EntranceLookup(multiworld.worlds[1].random, coupled=True, usable_exits=exits_set, targets=er_targets)
target = lookup.find_target("region0_right", ERTestGroups.RIGHT)
self.assertEqual(target.name, "region0_right")
self.assertEqual(target.randomization_group, ERTestGroups.RIGHT)
# wrong group
self.assertIsNone(lookup.find_target("region0_right", ERTestGroups.LEFT))
def test_find_target_by_name_and_group_and_category(self):
"""Tests that find_target can find the correct target by name, group, and dead-endedness"""
multiworld = generate_test_multiworld()
generate_disconnected_region_grid(multiworld, 5)
exits_set = set([ex for region in multiworld.get_regions(1)
for ex in region.exits if not ex.connected_region])
er_targets = [entrance for region in multiworld.get_regions(1)
for entrance in region.entrances if not entrance.parent_region]
lookup = EntranceLookup(multiworld.worlds[1].random, coupled=True, usable_exits=exits_set, targets=er_targets)
target = lookup.find_target("region0_right", ERTestGroups.RIGHT, False)
self.assertEqual(target.name, "region0_right")
self.assertEqual(target.randomization_group, ERTestGroups.RIGHT)
# wrong deadendedness
self.assertIsNone(lookup.find_target("region0_right", ERTestGroups.RIGHT, True))
class TestBakeTargetGroupLookup(unittest.TestCase): class TestBakeTargetGroupLookup(unittest.TestCase):
def test_lookup_generation(self): def test_lookup_generation(self):
multiworld = generate_test_multiworld() multiworld = generate_test_multiworld()
@@ -265,12 +309,12 @@ class TestRandomizeEntrances(unittest.TestCase):
generate_disconnected_region_grid(multiworld, 5) generate_disconnected_region_grid(multiworld, 5)
seen_placement_count = 0 seen_placement_count = 0
def verify_coupled(_: ERPlacementState, placed_entrances: list[Entrance]): def verify_coupled(_: ERPlacementState, placed_exits: list[Entrance], placed_targets: list[Entrance]):
nonlocal seen_placement_count nonlocal seen_placement_count
seen_placement_count += len(placed_entrances) seen_placement_count += len(placed_exits)
self.assertEqual(2, len(placed_entrances)) self.assertEqual(2, len(placed_exits))
self.assertEqual(placed_entrances[0].parent_region, placed_entrances[1].connected_region) self.assertEqual(placed_exits[0].parent_region, placed_exits[1].connected_region)
self.assertEqual(placed_entrances[1].parent_region, placed_entrances[0].connected_region) self.assertEqual(placed_exits[1].parent_region, placed_exits[0].connected_region)
result = randomize_entrances(multiworld.worlds[1], True, directionally_matched_group_lookup, result = randomize_entrances(multiworld.worlds[1], True, directionally_matched_group_lookup,
on_connect=verify_coupled) on_connect=verify_coupled)
@@ -313,10 +357,10 @@ class TestRandomizeEntrances(unittest.TestCase):
generate_disconnected_region_grid(multiworld, 5) generate_disconnected_region_grid(multiworld, 5)
seen_placement_count = 0 seen_placement_count = 0
def verify_uncoupled(state: ERPlacementState, placed_entrances: list[Entrance]): def verify_uncoupled(state: ERPlacementState, placed_exits: list[Entrance], placed_targets: list[Entrance]):
nonlocal seen_placement_count nonlocal seen_placement_count
seen_placement_count += len(placed_entrances) seen_placement_count += len(placed_exits)
self.assertEqual(1, len(placed_entrances)) self.assertEqual(1, len(placed_exits))
result = randomize_entrances(multiworld.worlds[1], False, directionally_matched_group_lookup, result = randomize_entrances(multiworld.worlds[1], False, directionally_matched_group_lookup,
on_connect=verify_uncoupled) on_connect=verify_uncoupled)