mirror of
https://github.com/MarioSpore/Grinch-AP.git
synced 2025-10-20 11:51:32 -06:00
GER: Move EntranceLookup onto ERPlacementState. Improve usefulness of on_connect. (#4904)
Co-authored-by: Exempt-Medic <60412657+Exempt-Medic@users.noreply.github.com>
This commit is contained in:
@@ -52,13 +52,15 @@ class EntranceLookup:
|
||||
_coupled: bool
|
||||
_usable_exits: set[Entrance]
|
||||
|
||||
def __init__(self, rng: random.Random, coupled: bool, usable_exits: set[Entrance]):
|
||||
def __init__(self, rng: random.Random, coupled: bool, usable_exits: set[Entrance], targets: Iterable[Entrance]):
|
||||
self.dead_ends = EntranceLookup.GroupLookup()
|
||||
self.others = EntranceLookup.GroupLookup()
|
||||
self._random = rng
|
||||
self._expands_graph_cache = {}
|
||||
self._coupled = coupled
|
||||
self._usable_exits = usable_exits
|
||||
for target in targets:
|
||||
self.add(target)
|
||||
|
||||
def _can_expand_graph(self, entrance: Entrance) -> bool:
|
||||
"""
|
||||
@@ -121,7 +123,14 @@ class EntranceLookup:
|
||||
dead_end: bool,
|
||||
preserve_group_order: bool
|
||||
) -> Iterable[Entrance]:
|
||||
"""
|
||||
Gets available targets for the requested groups
|
||||
|
||||
:param groups: The groups to find targets for
|
||||
:param dead_end: Whether to find dead ends. If false, finds non-dead-ends
|
||||
:param preserve_group_order: Whether to preserve the group order in the returned iterable. If true, a sequence
|
||||
like AAABBB is guaranteed. If false, groups can be interleaved, e.g. BAABAB.
|
||||
"""
|
||||
lookup = self.dead_ends if dead_end else self.others
|
||||
if preserve_group_order:
|
||||
for group in groups:
|
||||
@@ -132,6 +141,27 @@ class EntranceLookup:
|
||||
self._random.shuffle(ret)
|
||||
return ret
|
||||
|
||||
def find_target(self, name: str, group: int | None = None, dead_end: bool | None = None) -> Entrance | None:
|
||||
"""
|
||||
Finds a specific target in the lookup, if it is present.
|
||||
|
||||
:param name: The name of the target
|
||||
:param group: The target's group. Providing this will make the lookup faster, but can be omitted if it is not
|
||||
known ahead of time for some reason.
|
||||
:param dead_end: Whether the target is a dead end. Providing this will make the lookup faster, but can be
|
||||
omitted if this is not known ahead of time (much more likely)
|
||||
"""
|
||||
if dead_end is None:
|
||||
return (found
|
||||
if (found := self.find_target(name, group, True))
|
||||
else self.find_target(name, group, False))
|
||||
lookup = self.dead_ends if dead_end else self.others
|
||||
targets_to_check = lookup if group is None else lookup[group]
|
||||
for target in targets_to_check:
|
||||
if target.name == name:
|
||||
return target
|
||||
return None
|
||||
|
||||
def __len__(self):
|
||||
return len(self.dead_ends) + len(self.others)
|
||||
|
||||
@@ -146,15 +176,18 @@ class ERPlacementState:
|
||||
"""The world which is having its entrances randomized"""
|
||||
collection_state: CollectionState
|
||||
"""The CollectionState backing the entrance randomization logic"""
|
||||
entrance_lookup: EntranceLookup
|
||||
"""A lookup table of all unconnected ER targets"""
|
||||
coupled: bool
|
||||
"""Whether entrance randomization is operating in coupled mode"""
|
||||
|
||||
def __init__(self, world: World, coupled: bool):
|
||||
def __init__(self, world: World, entrance_lookup: EntranceLookup, coupled: bool):
|
||||
self.placements = []
|
||||
self.pairings = []
|
||||
self.world = world
|
||||
self.coupled = coupled
|
||||
self.collection_state = world.multiworld.get_all_state(False, True)
|
||||
self.entrance_lookup = entrance_lookup
|
||||
|
||||
@property
|
||||
def placed_regions(self) -> set[Region]:
|
||||
@@ -182,6 +215,7 @@ class ERPlacementState:
|
||||
self.collection_state.stale[self.world.player] = True
|
||||
self.placements.append(source_exit)
|
||||
self.pairings.append((source_exit.name, target_entrance.name))
|
||||
self.entrance_lookup.remove(target_entrance)
|
||||
|
||||
def test_speculative_connection(self, source_exit: Entrance, target_entrance: Entrance,
|
||||
usable_exits: set[Entrance]) -> bool:
|
||||
@@ -311,7 +345,7 @@ def randomize_entrances(
|
||||
preserve_group_order: bool = False,
|
||||
er_targets: list[Entrance] | None = None,
|
||||
exits: list[Entrance] | None = None,
|
||||
on_connect: Callable[[ERPlacementState, list[Entrance]], None] | None = None
|
||||
on_connect: Callable[[ERPlacementState, list[Entrance], list[Entrance]], bool | None] | None = None
|
||||
) -> ERPlacementState:
|
||||
"""
|
||||
Randomizes Entrances for a single world in the multiworld.
|
||||
@@ -328,14 +362,18 @@ def randomize_entrances(
|
||||
:param exits: The list of exits (Entrance objects with no target region) to use for randomization.
|
||||
Remember to be deterministic! If not provided, automatically discovers all valid exits in your world.
|
||||
:param on_connect: A callback function which allows specifying side effects after a placement is completed
|
||||
successfully and the underlying collection state has been updated.
|
||||
successfully and the underlying collection state has been updated. The arguments are
|
||||
1. The ER state
|
||||
2. The exits placed in this placement pass
|
||||
3. The entrances they were connected to.
|
||||
If you use on_connect to make additional placements, you are expected to return True to inform
|
||||
GER that an additional sweep is needed.
|
||||
"""
|
||||
if not world.explicit_indirect_conditions:
|
||||
raise EntranceRandomizationError("Entrance randomization requires explicit indirect conditions in order "
|
||||
+ "to correctly analyze whether dead end regions can be required in logic.")
|
||||
|
||||
start_time = time.perf_counter()
|
||||
er_state = ERPlacementState(world, coupled)
|
||||
# similar to fill, skip validity checks on entrances if the game is beatable on minimal accessibility
|
||||
perform_validity_check = True
|
||||
|
||||
@@ -351,23 +389,25 @@ def randomize_entrances(
|
||||
|
||||
# used when membership checks are needed on the exit list, e.g. speculative sweep
|
||||
exits_set = set(exits)
|
||||
entrance_lookup = EntranceLookup(world.random, coupled, exits_set)
|
||||
for entrance in er_targets:
|
||||
entrance_lookup.add(entrance)
|
||||
|
||||
er_state = ERPlacementState(
|
||||
world,
|
||||
EntranceLookup(world.random, coupled, exits_set, er_targets),
|
||||
coupled
|
||||
)
|
||||
# place the menu region and connected start region(s)
|
||||
er_state.collection_state.update_reachable_regions(world.player)
|
||||
|
||||
def do_placement(source_exit: Entrance, target_entrance: Entrance) -> None:
|
||||
placed_exits, removed_entrances = er_state.connect(source_exit, target_entrance)
|
||||
# remove the placed targets from consideration
|
||||
for entrance in removed_entrances:
|
||||
entrance_lookup.remove(entrance)
|
||||
placed_exits, paired_entrances = er_state.connect(source_exit, target_entrance)
|
||||
# propagate new connections
|
||||
er_state.collection_state.update_reachable_regions(world.player)
|
||||
er_state.collection_state.sweep_for_advancements()
|
||||
if on_connect:
|
||||
on_connect(er_state, placed_exits)
|
||||
change = on_connect(er_state, placed_exits, paired_entrances)
|
||||
if change:
|
||||
er_state.collection_state.update_reachable_regions(world.player)
|
||||
er_state.collection_state.sweep_for_advancements()
|
||||
|
||||
def needs_speculative_sweep(dead_end: bool, require_new_exits: bool, placeable_exits: list[Entrance]) -> bool:
|
||||
# speculative sweep is expensive. We currently only do it as a last resort, if we might cap off the graph
|
||||
@@ -388,12 +428,12 @@ def randomize_entrances(
|
||||
# check to see if we are proposing the last placement
|
||||
if not coupled:
|
||||
# in uncoupled, this check is easy as there will only be one target.
|
||||
is_last_placement = len(entrance_lookup) == 1
|
||||
is_last_placement = len(er_state.entrance_lookup) == 1
|
||||
else:
|
||||
# a bit harder, there may be 1 or 2 targets depending on if the exit to place is one way or two way.
|
||||
# if it is two way, we can safely assume that one of the targets is the logical pair of the exit.
|
||||
desired_target_count = 2 if placeable_exits[0].randomization_type == EntranceType.TWO_WAY else 1
|
||||
is_last_placement = len(entrance_lookup) == desired_target_count
|
||||
is_last_placement = len(er_state.entrance_lookup) == desired_target_count
|
||||
# if it's not the last placement, we need a sweep
|
||||
return not is_last_placement
|
||||
|
||||
@@ -402,7 +442,7 @@ def randomize_entrances(
|
||||
placeable_exits = er_state.find_placeable_exits(perform_validity_check, exits)
|
||||
for source_exit in placeable_exits:
|
||||
target_groups = target_group_lookup[source_exit.randomization_group]
|
||||
for target_entrance in entrance_lookup.get_targets(target_groups, dead_end, preserve_group_order):
|
||||
for target_entrance in er_state.entrance_lookup.get_targets(target_groups, dead_end, preserve_group_order):
|
||||
# when requiring new exits, ideally we would like to make it so that every placement increases
|
||||
# (or keeps the same number of) reachable exits. The goal is to continue to expand the search space
|
||||
# so that we do not crash. In the interest of performance and bias reduction, generally, just checking
|
||||
@@ -420,7 +460,7 @@ def randomize_entrances(
|
||||
else:
|
||||
# no source exits had any valid target so this stage is deadlocked. retries may be implemented if early
|
||||
# deadlocking is a frequent issue.
|
||||
lookup = entrance_lookup.dead_ends if dead_end else entrance_lookup.others
|
||||
lookup = er_state.entrance_lookup.dead_ends if dead_end else er_state.entrance_lookup.others
|
||||
|
||||
# if we're in a stage where we're trying to get to new regions, we could also enter this
|
||||
# branch in a success state (when all regions of the preferred type have been placed, but there are still
|
||||
@@ -466,21 +506,21 @@ def randomize_entrances(
|
||||
f"All unplaced exits: {unplaced_exits}")
|
||||
|
||||
# stage 1 - try to place all the non-dead-end entrances
|
||||
while entrance_lookup.others:
|
||||
while er_state.entrance_lookup.others:
|
||||
if not find_pairing(dead_end=False, require_new_exits=True):
|
||||
break
|
||||
# stage 2 - try to place all the dead-end entrances
|
||||
while entrance_lookup.dead_ends:
|
||||
while er_state.entrance_lookup.dead_ends:
|
||||
if not find_pairing(dead_end=True, require_new_exits=True):
|
||||
break
|
||||
# stage 3 - all the regions should be placed at this point. We now need to connect dangling edges
|
||||
# stage 3a - get the rest of the dead ends (e.g. second entrances into already-visited regions)
|
||||
# doing this before the non-dead-ends is important to ensure there are enough connections to
|
||||
# go around
|
||||
while entrance_lookup.dead_ends:
|
||||
while er_state.entrance_lookup.dead_ends:
|
||||
find_pairing(dead_end=True, require_new_exits=False)
|
||||
# stage 3b - tie all the other loose ends connecting visited regions to each other
|
||||
while entrance_lookup.others:
|
||||
while er_state.entrance_lookup.others:
|
||||
find_pairing(dead_end=False, require_new_exits=False)
|
||||
|
||||
running_time = time.perf_counter() - start_time
|
||||
|
@@ -69,11 +69,9 @@ class TestEntranceLookup(unittest.TestCase):
|
||||
exits_set = set([ex for region in multiworld.get_regions(1)
|
||||
for ex in region.exits if not ex.connected_region])
|
||||
|
||||
lookup = EntranceLookup(multiworld.worlds[1].random, coupled=True, usable_exits=exits_set)
|
||||
er_targets = [entrance for region in multiworld.get_regions(1)
|
||||
for entrance in region.entrances if not entrance.parent_region]
|
||||
for entrance in er_targets:
|
||||
lookup.add(entrance)
|
||||
lookup = EntranceLookup(multiworld.worlds[1].random, coupled=True, usable_exits=exits_set, targets=er_targets)
|
||||
|
||||
retrieved_targets = lookup.get_targets([ERTestGroups.TOP, ERTestGroups.BOTTOM],
|
||||
False, False)
|
||||
@@ -92,11 +90,9 @@ class TestEntranceLookup(unittest.TestCase):
|
||||
exits_set = set([ex for region in multiworld.get_regions(1)
|
||||
for ex in region.exits if not ex.connected_region])
|
||||
|
||||
lookup = EntranceLookup(multiworld.worlds[1].random, coupled=True, usable_exits=exits_set)
|
||||
er_targets = [entrance for region in multiworld.get_regions(1)
|
||||
for entrance in region.entrances if not entrance.parent_region]
|
||||
for entrance in er_targets:
|
||||
lookup.add(entrance)
|
||||
lookup = EntranceLookup(multiworld.worlds[1].random, coupled=True, usable_exits=exits_set, targets=er_targets)
|
||||
|
||||
retrieved_targets = lookup.get_targets([ERTestGroups.TOP, ERTestGroups.BOTTOM],
|
||||
False, True)
|
||||
@@ -112,12 +108,10 @@ class TestEntranceLookup(unittest.TestCase):
|
||||
for ex in region.exits if not ex.connected_region
|
||||
and ex.name != "region20_right" and ex.name != "region21_left"])
|
||||
|
||||
lookup = EntranceLookup(multiworld.worlds[1].random, coupled=True, usable_exits=exits_set)
|
||||
er_targets = [entrance for region in multiworld.get_regions(1)
|
||||
for entrance in region.entrances if not entrance.parent_region and
|
||||
entrance.name != "region20_right" and entrance.name != "region21_left"]
|
||||
for entrance in er_targets:
|
||||
lookup.add(entrance)
|
||||
lookup = EntranceLookup(multiworld.worlds[1].random, coupled=True, usable_exits=exits_set, targets=er_targets)
|
||||
# region 20 is the bottom left corner of the grid, and therefore only has a right entrance from region 21
|
||||
# and a top entrance from region 15; since we've told lookup to ignore the right entrance from region 21,
|
||||
# the top entrance from region 15 should be considered a dead-end
|
||||
@@ -129,6 +123,56 @@ class TestEntranceLookup(unittest.TestCase):
|
||||
self.assertTrue(dead_end in lookup.dead_ends)
|
||||
self.assertEqual(len(lookup.dead_ends), 1)
|
||||
|
||||
def test_find_target_by_name(self):
|
||||
"""Tests that find_target can find the correct target by name only"""
|
||||
multiworld = generate_test_multiworld()
|
||||
generate_disconnected_region_grid(multiworld, 5)
|
||||
exits_set = set([ex for region in multiworld.get_regions(1)
|
||||
for ex in region.exits if not ex.connected_region])
|
||||
|
||||
er_targets = [entrance for region in multiworld.get_regions(1)
|
||||
for entrance in region.entrances if not entrance.parent_region]
|
||||
lookup = EntranceLookup(multiworld.worlds[1].random, coupled=True, usable_exits=exits_set, targets=er_targets)
|
||||
|
||||
target = lookup.find_target("region0_right")
|
||||
self.assertEqual(target.name, "region0_right")
|
||||
self.assertEqual(target.randomization_group, ERTestGroups.RIGHT)
|
||||
self.assertIsNone(lookup.find_target("nonexistant"))
|
||||
|
||||
def test_find_target_by_name_and_group(self):
|
||||
"""Tests that find_target can find the correct target by name and group"""
|
||||
multiworld = generate_test_multiworld()
|
||||
generate_disconnected_region_grid(multiworld, 5)
|
||||
exits_set = set([ex for region in multiworld.get_regions(1)
|
||||
for ex in region.exits if not ex.connected_region])
|
||||
|
||||
er_targets = [entrance for region in multiworld.get_regions(1)
|
||||
for entrance in region.entrances if not entrance.parent_region]
|
||||
lookup = EntranceLookup(multiworld.worlds[1].random, coupled=True, usable_exits=exits_set, targets=er_targets)
|
||||
|
||||
target = lookup.find_target("region0_right", ERTestGroups.RIGHT)
|
||||
self.assertEqual(target.name, "region0_right")
|
||||
self.assertEqual(target.randomization_group, ERTestGroups.RIGHT)
|
||||
# wrong group
|
||||
self.assertIsNone(lookup.find_target("region0_right", ERTestGroups.LEFT))
|
||||
|
||||
def test_find_target_by_name_and_group_and_category(self):
|
||||
"""Tests that find_target can find the correct target by name, group, and dead-endedness"""
|
||||
multiworld = generate_test_multiworld()
|
||||
generate_disconnected_region_grid(multiworld, 5)
|
||||
exits_set = set([ex for region in multiworld.get_regions(1)
|
||||
for ex in region.exits if not ex.connected_region])
|
||||
|
||||
er_targets = [entrance for region in multiworld.get_regions(1)
|
||||
for entrance in region.entrances if not entrance.parent_region]
|
||||
lookup = EntranceLookup(multiworld.worlds[1].random, coupled=True, usable_exits=exits_set, targets=er_targets)
|
||||
|
||||
target = lookup.find_target("region0_right", ERTestGroups.RIGHT, False)
|
||||
self.assertEqual(target.name, "region0_right")
|
||||
self.assertEqual(target.randomization_group, ERTestGroups.RIGHT)
|
||||
# wrong deadendedness
|
||||
self.assertIsNone(lookup.find_target("region0_right", ERTestGroups.RIGHT, True))
|
||||
|
||||
class TestBakeTargetGroupLookup(unittest.TestCase):
|
||||
def test_lookup_generation(self):
|
||||
multiworld = generate_test_multiworld()
|
||||
@@ -265,12 +309,12 @@ class TestRandomizeEntrances(unittest.TestCase):
|
||||
generate_disconnected_region_grid(multiworld, 5)
|
||||
seen_placement_count = 0
|
||||
|
||||
def verify_coupled(_: ERPlacementState, placed_entrances: list[Entrance]):
|
||||
def verify_coupled(_: ERPlacementState, placed_exits: list[Entrance], placed_targets: list[Entrance]):
|
||||
nonlocal seen_placement_count
|
||||
seen_placement_count += len(placed_entrances)
|
||||
self.assertEqual(2, len(placed_entrances))
|
||||
self.assertEqual(placed_entrances[0].parent_region, placed_entrances[1].connected_region)
|
||||
self.assertEqual(placed_entrances[1].parent_region, placed_entrances[0].connected_region)
|
||||
seen_placement_count += len(placed_exits)
|
||||
self.assertEqual(2, len(placed_exits))
|
||||
self.assertEqual(placed_exits[0].parent_region, placed_exits[1].connected_region)
|
||||
self.assertEqual(placed_exits[1].parent_region, placed_exits[0].connected_region)
|
||||
|
||||
result = randomize_entrances(multiworld.worlds[1], True, directionally_matched_group_lookup,
|
||||
on_connect=verify_coupled)
|
||||
@@ -313,10 +357,10 @@ class TestRandomizeEntrances(unittest.TestCase):
|
||||
generate_disconnected_region_grid(multiworld, 5)
|
||||
seen_placement_count = 0
|
||||
|
||||
def verify_uncoupled(state: ERPlacementState, placed_entrances: list[Entrance]):
|
||||
def verify_uncoupled(state: ERPlacementState, placed_exits: list[Entrance], placed_targets: list[Entrance]):
|
||||
nonlocal seen_placement_count
|
||||
seen_placement_count += len(placed_entrances)
|
||||
self.assertEqual(1, len(placed_entrances))
|
||||
seen_placement_count += len(placed_exits)
|
||||
self.assertEqual(1, len(placed_exits))
|
||||
|
||||
result = randomize_entrances(multiworld.worlds[1], False, directionally_matched_group_lookup,
|
||||
on_connect=verify_uncoupled)
|
||||
|
Reference in New Issue
Block a user