diff --git a/entrance_rando.py b/entrance_rando.py index 1d2fbc2b..ab329edf 100644 --- a/entrance_rando.py +++ b/entrance_rando.py @@ -265,14 +265,19 @@ def bake_target_group_lookup(world: World, get_target_groups: Callable[[int], li return { group: get_target_groups(group) for group in unique_groups } -def disconnect_entrance_for_randomization(entrance: Entrance, target_group: int | None = None) -> None: +def disconnect_entrance_for_randomization(entrance: Entrance, target_group: int | None = None, + one_way_target_name: str | None = None) -> None: """ Given an entrance in a "vanilla" region graph, splits that entrance to prepare it for randomization - in randomize_entrances. This should be done after setting the type and group of the entrance. + in randomize_entrances. This should be done after setting the type and group of the entrance. Because it attempts + to meet strict entrance naming requirements for coupled mode, this function may produce unintuitive results when + called only on a single entrance; it produces eventually-correct outputs only after calling it on all entrances. :param entrance: The entrance which will be disconnected in preparation for randomization. :param target_group: The group to assign to the created ER target. If not specified, the group from the original entrance will be copied. + :param one_way_target_name: The name of the created ER target if `entrance` is one-way. This argument + is required for one-way entrances and is ignored otherwise. """ child_region = entrance.connected_region parent_region = entrance.parent_region @@ -287,8 +292,11 @@ def disconnect_entrance_for_randomization(entrance: Entrance, target_group: int # targets in the child region will be created when the other direction edge is disconnected target = parent_region.create_er_target(entrance.name) else: - # for 1-ways, the child region needs a target and coupling/naming is not a concern - target = child_region.create_er_target(child_region.name) + # for 1-ways, the child region needs a target. naming is not a concern for coupling so we + # allow it to be user provided (and require it, to prevent an unhelpful assumed name in pairings) + if not one_way_target_name: + raise ValueError("Cannot disconnect a one-way entrance without a target name specified") + target = child_region.create_er_target(one_way_target_name) target.randomization_type = entrance.randomization_type target.randomization_group = target_group or entrance.randomization_group diff --git a/test/general/test_entrance_rando.py b/test/general/test_entrance_rando.py index d2c1e168..542b3b4b 100644 --- a/test/general/test_entrance_rando.py +++ b/test/general/test_entrance_rando.py @@ -148,7 +148,7 @@ class TestDisconnectForRandomization(unittest.TestCase): e.randomization_group = 1 e.connect(r2) - disconnect_entrance_for_randomization(e) + disconnect_entrance_for_randomization(e, one_way_target_name="foo") self.assertIsNone(e.connected_region) self.assertEqual([], r1.entrances) @@ -158,10 +158,22 @@ class TestDisconnectForRandomization(unittest.TestCase): self.assertEqual(1, len(r2.entrances)) self.assertIsNone(r2.entrances[0].parent_region) - self.assertEqual("r2", r2.entrances[0].name) + self.assertEqual("foo", r2.entrances[0].name) self.assertEqual(EntranceType.ONE_WAY, r2.entrances[0].randomization_type) self.assertEqual(1, r2.entrances[0].randomization_group) + def test_disconnect_default_1way_no_vanilla_target_raises(self): + multiworld = generate_test_multiworld() + r1 = Region("r1", 1, multiworld) + r2 = Region("r2", 1, multiworld) + e = r1.create_exit("e") + e.randomization_type = EntranceType.ONE_WAY + e.randomization_group = 1 + e.connect(r2) + + with self.assertRaises(ValueError): + disconnect_entrance_for_randomization(e) + def test_disconnect_uses_alternate_group(self): multiworld = generate_test_multiworld() r1 = Region("r1", 1, multiworld) @@ -171,7 +183,7 @@ class TestDisconnectForRandomization(unittest.TestCase): e.randomization_group = 1 e.connect(r2) - disconnect_entrance_for_randomization(e, 2) + disconnect_entrance_for_randomization(e, 2, "foo") self.assertIsNone(e.connected_region) self.assertEqual([], r1.entrances) @@ -181,7 +193,7 @@ class TestDisconnectForRandomization(unittest.TestCase): self.assertEqual(1, len(r2.entrances)) self.assertIsNone(r2.entrances[0].parent_region) - self.assertEqual("r2", r2.entrances[0].name) + self.assertEqual("foo", r2.entrances[0].name) self.assertEqual(EntranceType.ONE_WAY, r2.entrances[0].randomization_type) self.assertEqual(2, r2.entrances[0].randomization_group)