From f33babc4206f6dffb4ec41e7a73057ba8dbdbc8c Mon Sep 17 00:00:00 2001 From: Aaron Wagener Date: Sat, 30 Sep 2023 04:53:11 -0500 Subject: [PATCH] Tests: add a name removal method (#2233) * Tests: add a name removal method, and have assertAccessDependency use and dispose its own state * Update test/TestBase.py --------- Co-authored-by: black-sliver <59490463+black-sliver@users.noreply.github.com> --- test/TestBase.py | 44 +++++++++++++++++++++++++++++++------------- 1 file changed, 31 insertions(+), 13 deletions(-) diff --git a/test/TestBase.py b/test/TestBase.py index 1f0853ef..856428fb 100644 --- a/test/TestBase.py +++ b/test/TestBase.py @@ -141,13 +141,16 @@ class WorldTestBase(unittest.TestCase): call_all(self.multiworld, step) # methods that can be called within tests - def collect_all_but(self, item_names: typing.Union[str, typing.Iterable[str]]) -> None: + def collect_all_but(self, item_names: typing.Union[str, typing.Iterable[str]], + state: typing.Optional[CollectionState] = None) -> None: """Collects all pre-placed items and items in the multiworld itempool except those provided""" if isinstance(item_names, str): item_names = (item_names,) + if not state: + state = self.multiworld.state for item in self.multiworld.get_items(): if item.name not in item_names: - self.multiworld.state.collect(item) + state.collect(item) def get_item_by_name(self, item_name: str) -> Item: """Returns the first item found in placed items, or in the itempool with the matching name""" @@ -174,6 +177,12 @@ class WorldTestBase(unittest.TestCase): items = (items,) for item in items: self.multiworld.state.collect(item) + + def remove_by_name(self, item_names: typing.Union[str, typing.Iterable[str]]) -> typing.List[Item]: + """Remove all of the items in the item pool with the given names from state""" + items = self.get_items_by_name(item_names) + self.remove(items) + return items def remove(self, items: typing.Union[Item, typing.Iterable[Item]]) -> None: """Removes the provided item(s) from state""" @@ -198,23 +207,32 @@ class WorldTestBase(unittest.TestCase): def assertAccessDependency(self, locations: typing.List[str], - possible_items: typing.Iterable[typing.Iterable[str]]) -> None: + possible_items: typing.Iterable[typing.Iterable[str]], + only_check_listed: bool = False) -> None: """Asserts that the provided locations can't be reached without the listed items but can be reached with any one of the provided combinations""" all_items = [item_name for item_names in possible_items for item_name in item_names] - self.collect_all_but(all_items) - for location in self.multiworld.get_locations(): - loc_reachable = self.multiworld.state.can_reach(location) - self.assertEqual(loc_reachable, location.name not in locations, - f"{location.name} is reachable without {all_items}" if loc_reachable - else f"{location.name} is not reachable without {all_items}") - for item_names in possible_items: - items = self.collect_by_name(item_names) + state = CollectionState(self.multiworld) + self.collect_all_but(all_items, state) + if only_check_listed: for location in locations: - self.assertTrue(self.can_reach_location(location), + self.assertFalse(state.can_reach(location, "Location", 1), f"{location} is reachable without {all_items}") + else: + for location in self.multiworld.get_locations(): + loc_reachable = state.can_reach(location, "Location", 1) + self.assertEqual(loc_reachable, location.name not in locations, + f"{location.name} is reachable without {all_items}" if loc_reachable + else f"{location.name} is not reachable without {all_items}") + for item_names in possible_items: + items = self.get_items_by_name(item_names) + for item in items: + state.collect(item) + for location in locations: + self.assertTrue(state.can_reach(location, "Location", 1), f"{location} not reachable with {item_names}") - self.remove(items) + for item in items: + state.remove(item) def assertBeatable(self, beatable: bool): """Asserts that the game can be beaten with the current state"""