diff --git a/BaseClasses.py b/BaseClasses.py index 8fe108a7..f2aea389 100644 --- a/BaseClasses.py +++ b/BaseClasses.py @@ -836,20 +836,23 @@ class Region: for location, address in locations.items(): self.locations.append(location_type(self.player, location, address, self)) - def add_exits(self, exits: Dict[str, Optional[str]], rules: Dict[str, Callable[[CollectionState], bool]] = None) -> None: + def add_exits(self, exits: Union[Iterable[str], Dict[str, Optional[str]]], + rules: Dict[str, Callable[[CollectionState], bool]] = None) -> None: """ Connects current region to regions in exit dictionary. Passed region names must exist first. - :param exits: exits from the region. format is {"connecting_region", "exit_name"} + :param exits: exits from the region. format is {"connecting_region": "exit_name"}. if a non dict is provided, + created entrances will be named "self.name -> connecting_region" :param rules: rules for the exits from this region. format is {"connecting_region", rule} """ - for exiting_region, name in exits.items(): - ret = Entrance(self.player, name, self) if name \ - else Entrance(self.player, f"{self.name} -> {exiting_region}", self) - if rules and exiting_region in rules: - ret.access_rule = rules[exiting_region] - self.exits.append(ret) - ret.connect(self.multiworld.get_region(exiting_region, self.player)) + if not isinstance(exits, Dict): + exits = dict.fromkeys(exits) + for connecting_region, name in exits.items(): + entrance = Entrance(self.player, name if name else f"{self.name} -> {connecting_region}", self) + if rules and connecting_region in rules: + entrance.access_rule = rules[connecting_region] + self.exits.append(entrance) + entrance.connect(self.multiworld.get_region(connecting_region, self.player)) def __repr__(self): return self.__str__() diff --git a/test/general/TestHelpers.py b/test/general/TestHelpers.py index b6b1ea47..c0b560c7 100644 --- a/test/general/TestHelpers.py +++ b/test/general/TestHelpers.py @@ -19,6 +19,7 @@ class TestHelpers(unittest.TestCase): regions: Dict[str, str] = { "TestRegion1": "I'm an apple", "TestRegion2": "I'm a banana", + "TestRegion3": "Empty Region", } locations: Dict[str, Dict[str, Optional[int]]] = { @@ -38,6 +39,10 @@ class TestHelpers(unittest.TestCase): "TestRegion2": {"TestRegion1": None}, } + reg_exit_set: Dict[str, set[str]] = { + "TestRegion1": {"TestRegion3"} + } + exit_rules: Dict[str, Callable[[CollectionState], bool]] = { "TestRegion1": lambda state: state.has("test_item", self.player) } @@ -68,3 +73,10 @@ class TestHelpers(unittest.TestCase): entrance_name = exit_name if exit_name else f"{parent} -> {exit_reg}" self.assertEqual(exit_rules[exit_reg], self.multiworld.get_entrance(entrance_name, self.player).access_rule) + + for region in reg_exit_set: + current_region = self.multiworld.get_region(region, self.player) + current_region.add_exits(reg_exit_set[region]) + exit_names = {_exit.name for _exit in current_region.exits} + for reg_exit in reg_exit_set[region]: + self.assertTrue(f"{region} -> {reg_exit}" in exit_names, f"{region} -> {reg_exit} not in {exit_names}")