| 
									
										
										
										
											2025-05-13 03:58:03 -04:00
										 |  |  | import itertools | 
					
						
							|  |  |  | import logging | 
					
						
							|  |  |  | import os | 
					
						
							|  |  |  | import threading | 
					
						
							|  |  |  | import typing | 
					
						
							|  |  |  | import unittest | 
					
						
							|  |  |  | from contextlib import contextmanager | 
					
						
							|  |  |  | from typing import Optional, Dict, Union, Any, List, Iterable | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-05-24 01:15:41 -04:00
										 |  |  | from BaseClasses import get_seed, MultiWorld, Location, Item, CollectionState, Entrance | 
					
						
							| 
									
										
										
										
											2025-05-13 03:58:03 -04:00
										 |  |  | from test.bases import WorldTestBase | 
					
						
							|  |  |  | from test.general import gen_steps, setup_solo_multiworld as setup_base_solo_multiworld | 
					
						
							|  |  |  | from worlds.AutoWorld import call_all | 
					
						
							|  |  |  | from .assertion import RuleAssertMixin | 
					
						
							|  |  |  | from .options.utils import parse_class_option_keys, fill_namespace_with_default | 
					
						
							|  |  |  | from .. import StardewValleyWorld, StardewItem, StardewRule | 
					
						
							|  |  |  | from ..logic.time_logic import MONTH_COEFFICIENT | 
					
						
							|  |  |  | from ..options import StardewValleyOption, options | 
					
						
							|  |  |  | 
 | 
					
						
logger = logging.getLogger(__name__)
# Drawn once, when this module is imported; the test helpers in this module use it as their default seed,
# so every test that does not choose its own seed shares it. It is logged, presumably so a failing run
# can be reproduced with the same seed.
DEFAULT_TEST_SEED = get_seed()
logger.info(f"Default Test Seed: {DEFAULT_TEST_SEED}")
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def skip_default_tests() -> bool: | 
					
						
							|  |  |  |     return not bool(os.environ.get("base", False)) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def skip_long_tests() -> bool: | 
					
						
							|  |  |  |     return not bool(os.environ.get("long", False)) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class SVTestCase(unittest.TestCase): | 
					
						
							|  |  |  |     skip_default_tests: bool = skip_default_tests() | 
					
						
							|  |  |  |     """Set False to not skip the base fill tests""" | 
					
						
							|  |  |  |     skip_long_tests: bool = skip_long_tests() | 
					
						
							|  |  |  |     """Set False to run tests that take long""" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @contextmanager | 
					
						
							|  |  |  |     def solo_world_sub_test(self, msg: str | None = None, | 
					
						
							|  |  |  |                             /, | 
					
						
							|  |  |  |                             world_options: dict[str | type[StardewValleyOption], Any] | None = None, | 
					
						
							|  |  |  |                             *, | 
					
						
							|  |  |  |                             seed=DEFAULT_TEST_SEED, | 
					
						
							|  |  |  |                             world_caching=True, | 
					
						
							|  |  |  |                             **kwargs) -> Iterable[tuple[MultiWorld, StardewValleyWorld]]: | 
					
						
							|  |  |  |         if msg is not None: | 
					
						
							|  |  |  |             msg += " " | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             msg = "" | 
					
						
							|  |  |  |         msg += f"[Seed = {seed}]" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         with self.subTest(msg, **kwargs): | 
					
						
							|  |  |  |             with solo_multiworld(world_options, seed=seed, world_caching=world_caching) as (multiworld, world): | 
					
						
							|  |  |  |                 yield multiworld, world | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class SVTestBase(RuleAssertMixin, WorldTestBase, SVTestCase): | 
					
						
							|  |  |  |     game = "Stardew Valley" | 
					
						
							|  |  |  |     world: StardewValleyWorld | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     seed = DEFAULT_TEST_SEED | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @classmethod | 
					
						
							|  |  |  |     def setUpClass(cls) -> None: | 
					
						
							|  |  |  |         if cls is SVTestBase: | 
					
						
							|  |  |  |             raise unittest.SkipTest("No running tests on SVTestBase import.") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         super().setUpClass() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def world_setup(self, *args, **kwargs): | 
					
						
							|  |  |  |         self.options = parse_class_option_keys(self.options) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         self.multiworld = setup_solo_multiworld(self.options, seed=self.seed) | 
					
						
							|  |  |  |         self.multiworld.lock.acquire() | 
					
						
							|  |  |  |         world = self.multiworld.worlds[self.player] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         self.original_state = self.multiworld.state.copy() | 
					
						
							|  |  |  |         self.original_itempool = self.multiworld.itempool.copy() | 
					
						
							|  |  |  |         self.unfilled_locations = self.multiworld.get_unfilled_locations(1) | 
					
						
							|  |  |  |         if self.constructed: | 
					
						
							|  |  |  |             self.world = world  # noqa | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def tearDown(self) -> None: | 
					
						
							|  |  |  |         self.multiworld.state = self.original_state | 
					
						
							|  |  |  |         self.multiworld.itempool = self.original_itempool | 
					
						
							|  |  |  |         for location in self.unfilled_locations: | 
					
						
							|  |  |  |             location.item = None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         self.multiworld.lock.release() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @property | 
					
						
							|  |  |  |     def run_default_tests(self) -> bool: | 
					
						
							|  |  |  |         if self.skip_default_tests: | 
					
						
							|  |  |  |             return False | 
					
						
							|  |  |  |         return super().run_default_tests | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def collect_months(self, months: int) -> None: | 
					
						
							|  |  |  |         real_total_prog_items = self.world.total_progression_items | 
					
						
							|  |  |  |         percent = months * MONTH_COEFFICIENT | 
					
						
							|  |  |  |         self.collect("Stardrop", real_total_prog_items * 100 // percent) | 
					
						
							|  |  |  |         self.world.total_progression_items = real_total_prog_items | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def collect_lots_of_money(self, percent: float = 0.25): | 
					
						
							|  |  |  |         self.collect("Shipping Bin") | 
					
						
							|  |  |  |         real_total_prog_items = self.world.total_progression_items | 
					
						
							|  |  |  |         required_prog_items = int(round(real_total_prog_items * percent)) | 
					
						
							|  |  |  |         self.collect("Stardrop", required_prog_items) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def collect_all_the_money(self): | 
					
						
							|  |  |  |         self.collect_lots_of_money(0.95) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def collect_everything(self): | 
					
						
							|  |  |  |         non_event_items = [item for item in self.multiworld.get_items() if item.code] | 
					
						
							|  |  |  |         for item in non_event_items: | 
					
						
							|  |  |  |             self.multiworld.state.collect(item) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def collect_all_except(self, item_to_not_collect: str): | 
					
						
							|  |  |  |         non_event_items = [item for item in self.multiworld.get_items() if item.code] | 
					
						
							|  |  |  |         for item in non_event_items: | 
					
						
							|  |  |  |             if item.name != item_to_not_collect: | 
					
						
							|  |  |  |                 self.multiworld.state.collect(item) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def get_real_locations(self) -> List[Location]: | 
					
						
							|  |  |  |         return [location for location in self.multiworld.get_locations(self.player) if location.address is not None] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def get_real_location_names(self) -> List[str]: | 
					
						
							|  |  |  |         return [location.name for location in self.get_real_locations()] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def collect(self, item: Union[str, Item, Iterable[Item]], count: int = 1) -> Union[None, Item, List[Item]]: | 
					
						
							|  |  |  |         assert count > 0 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if not isinstance(item, str): | 
					
						
							|  |  |  |             super().collect(item) | 
					
						
							|  |  |  |             return | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if count == 1: | 
					
						
							|  |  |  |             item = self.create_item(item) | 
					
						
							|  |  |  |             self.multiworld.state.collect(item) | 
					
						
							|  |  |  |             return item | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         items = [] | 
					
						
							|  |  |  |         for i in range(count): | 
					
						
							|  |  |  |             item = self.create_item(item) | 
					
						
							|  |  |  |             self.multiworld.state.collect(item) | 
					
						
							|  |  |  |             items.append(item) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return items | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def create_item(self, item: str) -> StardewItem: | 
					
						
							|  |  |  |         return self.world.create_item(item) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def get_all_created_items(self) -> list[str]: | 
					
						
							|  |  |  |         return [item.name for item in itertools.chain(self.multiworld.get_items(), self.multiworld.precollected_items[self.player])] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def remove_one_by_name(self, item: str) -> None: | 
					
						
							|  |  |  |         self.remove(self.create_item(item)) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def reset_collection_state(self) -> None: | 
					
						
							|  |  |  |         self.multiworld.state = self.original_state.copy() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def assert_rule_true(self, rule: StardewRule, state: CollectionState | None = None) -> None: | 
					
						
							|  |  |  |         if state is None: | 
					
						
							|  |  |  |             state = self.multiworld.state | 
					
						
							|  |  |  |         super().assert_rule_true(rule, state) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def assert_rule_false(self, rule: StardewRule, state: CollectionState | None = None) -> None: | 
					
						
							|  |  |  |         if state is None: | 
					
						
							|  |  |  |             state = self.multiworld.state | 
					
						
							|  |  |  |         super().assert_rule_false(rule, state) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def assert_can_reach_location(self, location: Location | str, state: CollectionState | None = None) -> None: | 
					
						
							|  |  |  |         if state is None: | 
					
						
							|  |  |  |             state = self.multiworld.state | 
					
						
							|  |  |  |         super().assert_can_reach_location(location, state) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def assert_cannot_reach_location(self, location: Location | str, state: CollectionState | None = None) -> None: | 
					
						
							|  |  |  |         if state is None: | 
					
						
							|  |  |  |             state = self.multiworld.state | 
					
						
							|  |  |  |         super().assert_cannot_reach_location(location, state) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-05-24 01:15:41 -04:00
										 |  |  |     def assert_can_reach_entrance(self, entrance: Entrance | str, state: CollectionState | None = None) -> None: | 
					
						
							|  |  |  |         if state is None: | 
					
						
							|  |  |  |             state = self.multiworld.state | 
					
						
							|  |  |  |         super().assert_can_reach_entrance(entrance, state) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-05-13 03:58:03 -04:00
										 |  |  | 
 | 
					
						
# NOTE(review): not referenced by any code visible in this module; it may be used by other test modules
# or be left over from an older caching scheme — confirm before removing.
pre_generated_worlds = {}
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @contextmanager | 
					
						
							|  |  |  | def solo_multiworld(world_options: dict[str | type[StardewValleyOption], Any] | None = None, | 
					
						
							|  |  |  |                     *, | 
					
						
							|  |  |  |                     seed=DEFAULT_TEST_SEED, | 
					
						
							|  |  |  |                     world_caching=True) -> Iterable[tuple[MultiWorld, StardewValleyWorld]]: | 
					
						
							|  |  |  |     if not world_caching: | 
					
						
							|  |  |  |         multiworld = setup_solo_multiworld(world_options, seed, _cache={}) | 
					
						
							|  |  |  |         yield multiworld, typing.cast(StardewValleyWorld, multiworld.worlds[1]) | 
					
						
							|  |  |  |     else: | 
					
						
							|  |  |  |         multiworld = setup_solo_multiworld(world_options, seed) | 
					
						
							|  |  |  |         try: | 
					
						
							|  |  |  |             multiworld.lock.acquire() | 
					
						
							|  |  |  |             world = multiworld.worlds[1] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             original_state = multiworld.state.copy() | 
					
						
							|  |  |  |             original_itempool = multiworld.itempool.copy() | 
					
						
							|  |  |  |             unfilled_locations = multiworld.get_unfilled_locations(1) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             yield multiworld, typing.cast(StardewValleyWorld, world) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             multiworld.state = original_state | 
					
						
							|  |  |  |             multiworld.itempool = original_itempool | 
					
						
							|  |  |  |             for location in unfilled_locations: | 
					
						
							|  |  |  |                 location.item = None | 
					
						
							|  |  |  |         finally: | 
					
						
							|  |  |  |             multiworld.lock.release() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | # Mostly a copy of test.general.setup_solo_multiworld, I just don't want to change the core. | 
					
						
							|  |  |  | def setup_solo_multiworld(test_options: Optional[Dict[Union[str, StardewValleyOption], str]] = None, | 
					
						
							|  |  |  |                           seed=DEFAULT_TEST_SEED, | 
					
						
							|  |  |  |                           _cache: Dict[frozenset, MultiWorld] = {},  # noqa | 
					
						
							|  |  |  |                           _steps=gen_steps) -> MultiWorld: | 
					
						
							|  |  |  |     test_options = parse_class_option_keys(test_options) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Yes I reuse the worlds generated between tests, its speeds the execution by a couple seconds | 
					
						
							|  |  |  |     # If the simple dict caching ends up taking too much memory, we could replace it with some kind of lru cache. | 
					
						
							|  |  |  |     should_cache = should_cache_world(test_options) | 
					
						
							|  |  |  |     if should_cache: | 
					
						
							|  |  |  |         frozen_options = make_hashable(test_options, seed) | 
					
						
							|  |  |  |         cached_multi_world = search_world_cache(_cache, frozen_options) | 
					
						
							|  |  |  |         if cached_multi_world: | 
					
						
							|  |  |  |             print(f"Using cached solo multi world [Seed = {cached_multi_world.seed}] [Cache size = {len(_cache)}]") | 
					
						
							|  |  |  |             return cached_multi_world | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     multiworld = setup_base_solo_multiworld(StardewValleyWorld, (), seed=seed) | 
					
						
							|  |  |  |     # print(f"Seed: {multiworld.seed}") # Uncomment to print the seed for every test | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     args = fill_namespace_with_default(test_options) | 
					
						
							|  |  |  |     multiworld.set_options(args) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     if "start_inventory" in test_options: | 
					
						
							|  |  |  |         for item, amount in test_options["start_inventory"].items(): | 
					
						
							|  |  |  |             for _ in range(amount): | 
					
						
							|  |  |  |                 multiworld.push_precollected(multiworld.create_item(item, 1)) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     for step in _steps: | 
					
						
							|  |  |  |         call_all(multiworld, step) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     if should_cache: | 
					
						
							|  |  |  |         add_to_world_cache(_cache, frozen_options, multiworld)  # noqa | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Lock is needed for multi-threading tests | 
					
						
							|  |  |  |     setattr(multiworld, "lock", threading.Lock()) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return multiworld | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def should_cache_world(test_options): | 
					
						
							|  |  |  |     if "start_inventory" in test_options: | 
					
						
							|  |  |  |         return False | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     trap_distribution_key = "trap_distribution" | 
					
						
							|  |  |  |     if trap_distribution_key not in test_options: | 
					
						
							|  |  |  |         return True | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     trap_distribution = test_options[trap_distribution_key] | 
					
						
							|  |  |  |     for key in trap_distribution: | 
					
						
							|  |  |  |         if trap_distribution[key] != options.TrapDistribution.default_weight: | 
					
						
							|  |  |  |             return False | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return True | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def make_hashable(test_options, seed): | 
					
						
							|  |  |  |     return frozenset(test_options.items()).union({("seed", seed)}) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def search_world_cache(cache: Dict[frozenset, MultiWorld], frozen_options: frozenset) -> Optional[MultiWorld]: | 
					
						
							|  |  |  |     try: | 
					
						
							|  |  |  |         return cache[frozen_options] | 
					
						
							|  |  |  |     except KeyError: | 
					
						
							|  |  |  |         for cached_options, multi_world in cache.items(): | 
					
						
							|  |  |  |             if frozen_options.issubset(cached_options): | 
					
						
							|  |  |  |                 return multi_world | 
					
						
							|  |  |  |         return None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def add_to_world_cache(cache: Dict[frozenset, MultiWorld], frozen_options: frozenset, multi_world: MultiWorld) -> None: | 
					
						
							|  |  |  |     # We could complete the key with all the default options, but that does not seem to improve performances. | 
					
						
							|  |  |  |     cache[frozen_options] = multi_world | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def setup_multiworld(test_options: Iterable[Dict[str, int]] = None, seed=None) -> MultiWorld:  # noqa | 
					
						
							|  |  |  |     if test_options is None: | 
					
						
							|  |  |  |         test_options = [] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     multiworld = MultiWorld(len(test_options)) | 
					
						
							|  |  |  |     multiworld.player_name = {} | 
					
						
							|  |  |  |     multiworld.set_seed(seed) | 
					
						
							|  |  |  |     for i in range(1, len(test_options) + 1): | 
					
						
							|  |  |  |         multiworld.game[i] = StardewValleyWorld.game | 
					
						
							|  |  |  |         multiworld.player_name.update({i: f"Tester{i}"}) | 
					
						
							|  |  |  |     args = fill_namespace_with_default(test_options) | 
					
						
							|  |  |  |     multiworld.set_options(args) | 
					
						
							| 
									
										
										
										
											2025-05-22 08:27:18 -05:00
										 |  |  |     multiworld.state = CollectionState(multiworld) | 
					
						
							| 
									
										
										
										
											2025-05-13 03:58:03 -04:00
										 |  |  | 
 | 
					
						
							|  |  |  |     for step in gen_steps: | 
					
						
							|  |  |  |         call_all(multiworld, step) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return multiworld |