Files
Grinch-AP/worlds/stardew_valley/test/bases.py

312 lines
11 KiB
Python
Raw Normal View History

import itertools
import logging
import os
import threading
import typing
import unittest
from contextlib import contextmanager
from typing import Optional, Dict, Union, Any, List, Iterable
from BaseClasses import get_seed, MultiWorld, Location, Item, CollectionState, Entrance
from test.bases import WorldTestBase
from test.general import gen_steps, setup_solo_multiworld as setup_base_solo_multiworld
from worlds.AutoWorld import call_all
from .assertion import RuleAssertMixin
from .options.utils import parse_class_option_keys, fill_namespace_with_default
from .. import StardewValleyWorld, StardewItem, StardewRule
from ..logic.time_logic import MONTH_COEFFICIENT
from ..options import StardewValleyOption, options
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Seed used as the default by the helpers below; obtained once, at import time, from get_seed().
DEFAULT_TEST_SEED = get_seed()
# Record the seed so that a run can be traced back to it from the log output.
logger.info(f"Default Test Seed: {DEFAULT_TEST_SEED}")
def skip_default_tests() -> bool:
    """Tell whether the base tests should be skipped.

    Returns:
        ``False`` only when the ``base`` environment variable is set to a non-empty value;
        ``True`` when it is unset or empty.
    """
    base_flag = os.environ.get("base")
    return not base_flag
def skip_long_tests() -> bool:
    """Tell whether the long-running tests should be skipped.

    Returns:
        ``False`` only when the ``long`` environment variable is set to a non-empty value;
        ``True`` when it is unset or empty.
    """
    long_flag = os.environ.get("long")
    return not long_flag
class SVTestCase(unittest.TestCase):
    """``unittest`` test case carrying the skip flags and a solo-world sub-test helper."""

    skip_default_tests: bool = skip_default_tests()
    """Set False to not skip the base fill tests"""
    skip_long_tests: bool = skip_long_tests()
    """Set False to run tests that take long"""

    @contextmanager
    def solo_world_sub_test(self, msg: str | None = None,
                            /,
                            world_options: dict[str | type[StardewValleyOption], Any] | None = None,
                            *,
                            seed=DEFAULT_TEST_SEED,
                            world_caching=True,
                            **kwargs) -> Iterable[tuple[MultiWorld, StardewValleyWorld]]:
        """Open a sub-test wrapped around a solo multiworld.

        Args:
            msg: Optional sub-test message (positional-only). The seed is always appended to it.
            world_options: Options forwarded to ``solo_multiworld``.
            seed: Seed forwarded to ``solo_multiworld`` and shown in the sub-test message.
            world_caching: Forwarded to ``solo_multiworld``.
            **kwargs: Extra parameters forwarded to ``self.subTest``.

        Yields:
            The ``(multiworld, world)`` pair produced by ``solo_multiworld``.
        """
        prefix = "" if msg is None else msg + " "
        label = prefix + f"[Seed = {seed}]"

        # The sub-test is entered first, then the multiworld context, exactly as with nested blocks.
        with self.subTest(label, **kwargs), \
                solo_multiworld(world_options, seed=seed, world_caching=world_caching) as (multiworld, world):
            yield multiworld, world
class SVTestBase(RuleAssertMixin, WorldTestBase, SVTestCase):
    """Base class for Stardew Valley world tests running against a single-player multiworld.

    The multiworld comes from ``setup_solo_multiworld`` and may therefore be shared with other
    tests. ``world_setup`` takes the multiworld's lock and snapshots its state, and ``tearDown``
    restores the snapshot and releases the lock.
    """
    game = "Stardew Valley"
    world: StardewValleyWorld
    seed = DEFAULT_TEST_SEED

    @classmethod
    def setUpClass(cls) -> None:
        """Skip the class when it is ``SVTestBase`` itself; otherwise defer to the parent."""
        if cls is SVTestBase:
            raise unittest.SkipTest("No running tests on SVTestBase import.")

        super().setUpClass()

    def world_setup(self, *args, **kwargs):
        """Obtain the multiworld for this test, lock it and snapshot what ``tearDown`` restores."""
        self.options = parse_class_option_keys(self.options)

        self.multiworld = setup_solo_multiworld(self.options, seed=self.seed)
        # Held until tearDown so that no other user of the same multiworld runs concurrently.
        self.multiworld.lock.acquire()
        world = self.multiworld.worlds[self.player]

        self.original_state = self.multiworld.state.copy()
        self.original_itempool = self.multiworld.itempool.copy()
        self.unfilled_locations = self.multiworld.get_unfilled_locations(1)
        if self.constructed:
            self.world = world  # noqa

    def tearDown(self) -> None:
        """Restore the multiworld to the snapshot taken in ``world_setup`` and release its lock."""
        try:
            self.multiworld.state = self.original_state
            self.multiworld.itempool = self.original_itempool
            for location in self.unfilled_locations:
                location.item = None
        finally:
            # Always release, otherwise a failed restore would block every later user of this multiworld.
            self.multiworld.lock.release()

    @property
    def run_default_tests(self) -> bool:
        """``False`` when ``skip_default_tests`` is set, otherwise whatever the parent class decides."""
        if self.skip_default_tests:
            return False

        return super().run_default_tests

    def collect_months(self, months: int) -> None:
        """Collect "Stardrop" items in a quantity derived from ``months`` and ``MONTH_COEFFICIENT``.

        The world's ``total_progression_items`` is put back to the value read before collecting.
        """
        real_total_prog_items = self.world.total_progression_items
        percent = months * MONTH_COEFFICIENT
        # NOTE(review): the count is total * 100 // percent; confirm this is the intended quantity.
        self.collect("Stardrop", real_total_prog_items * 100 // percent)
        self.world.total_progression_items = real_total_prog_items

    def collect_lots_of_money(self, percent: float = 0.25):
        """Collect "Shipping Bin" plus ``percent`` of the world's progression item total in "Stardrop"."""
        self.collect("Shipping Bin")
        real_total_prog_items = self.world.total_progression_items
        required_prog_items = int(round(real_total_prog_items * percent))
        self.collect("Stardrop", required_prog_items)

    def collect_all_the_money(self):
        """Shortcut for ``collect_lots_of_money(0.95)``."""
        self.collect_lots_of_money(0.95)

    def collect_everything(self):
        """Collect every item of the multiworld that has a truthy ``code``."""
        non_event_items = [item for item in self.multiworld.get_items() if item.code]
        for item in non_event_items:
            self.multiworld.state.collect(item)

    def collect_all_except(self, item_to_not_collect: str):
        """Collect every item with a truthy ``code`` whose name is not ``item_to_not_collect``."""
        non_event_items = [item for item in self.multiworld.get_items() if item.code]
        for item in non_event_items:
            if item.name != item_to_not_collect:
                self.multiworld.state.collect(item)

    def get_real_locations(self) -> List[Location]:
        """Return this player's locations whose ``address`` is not ``None``."""
        return [location for location in self.multiworld.get_locations(self.player) if location.address is not None]

    def get_real_location_names(self) -> List[str]:
        """Return the names of the locations given by ``get_real_locations``."""
        return [location.name for location in self.get_real_locations()]

    def collect(self, item: Union[str, Item, Iterable[Item]], count: int = 1) -> Union[None, Item, List[Item]]:
        """Collect items into the multiworld state.

        Args:
            item: An item name, or an item / iterable of items handed to the parent ``collect``.
            count: How many items to create and collect when ``item`` is a name. Must be positive.

        Returns:
            ``None`` when ``item`` is not a name, the created item when ``count`` is 1,
            otherwise the list of created items.
        """
        assert count > 0

        if not isinstance(item, str):
            super().collect(item)
            return

        if count == 1:
            created = self.create_item(item)
            self.multiworld.state.collect(created)
            return created

        created_items = []
        for _ in range(count):
            # Always create from the name; the parameter is no longer overwritten by the created item.
            created = self.create_item(item)
            self.multiworld.state.collect(created)
            created_items.append(created)

        return created_items

    def create_item(self, item: str) -> StardewItem:
        """Create an item by name through the world under test."""
        return self.world.create_item(item)

    def get_all_created_items(self) -> list[str]:
        """Return the names of the multiworld's items followed by this player's precollected items."""
        return [item.name for item in itertools.chain(self.multiworld.get_items(), self.multiworld.precollected_items[self.player])]

    def remove_one_by_name(self, item: str) -> None:
        """Create an item by name and pass it to ``self.remove``."""
        self.remove(self.create_item(item))

    def reset_collection_state(self) -> None:
        """Replace the multiworld state with a copy of the snapshot taken in ``world_setup``."""
        self.multiworld.state = self.original_state.copy()

    def assert_rule_true(self, rule: StardewRule, state: CollectionState | None = None) -> None:
        """Same as the parent assertion, with ``state`` defaulting to the multiworld state."""
        if state is None:
            state = self.multiworld.state
        super().assert_rule_true(rule, state)

    def assert_rule_false(self, rule: StardewRule, state: CollectionState | None = None) -> None:
        """Same as the parent assertion, with ``state`` defaulting to the multiworld state."""
        if state is None:
            state = self.multiworld.state
        super().assert_rule_false(rule, state)

    def assert_can_reach_location(self, location: Location | str, state: CollectionState | None = None) -> None:
        """Same as the parent assertion, with ``state`` defaulting to the multiworld state."""
        if state is None:
            state = self.multiworld.state
        super().assert_can_reach_location(location, state)

    def assert_cannot_reach_location(self, location: Location | str, state: CollectionState | None = None) -> None:
        """Same as the parent assertion, with ``state`` defaulting to the multiworld state."""
        if state is None:
            state = self.multiworld.state
        super().assert_cannot_reach_location(location, state)

    def assert_can_reach_entrance(self, entrance: Entrance | str, state: CollectionState | None = None) -> None:
        """Same as the parent assertion, with ``state`` defaulting to the multiworld state."""
        if state is None:
            state = self.multiworld.state
        super().assert_can_reach_entrance(entrance, state)
# Module-level dict; nothing in the visible part of this file reads or writes it.
# NOTE(review): possibly a leftover or used by other modules - confirm before removing.
pre_generated_worlds = {}
@contextmanager
def solo_multiworld(world_options: dict[str | type[StardewValleyOption], Any] | None = None,
                    *,
                    seed=DEFAULT_TEST_SEED,
                    world_caching=True) -> Iterable[tuple[MultiWorld, StardewValleyWorld]]:
    """Provide a solo multiworld and its player-1 world for the duration of a ``with`` block.

    Args:
        world_options: Options forwarded to ``setup_solo_multiworld``.
        seed: Seed forwarded to ``setup_solo_multiworld``.
        world_caching: When false, ``setup_solo_multiworld`` is given a fresh private cache and
            the multiworld is yielded without locking or restoration. When true, the multiworld
            may come from the shared cache; it is locked for the duration of the block and its
            state, item pool and unfilled locations are restored afterwards.

    Yields:
        The ``(multiworld, world)`` pair.
    """
    if not world_caching:
        multiworld = setup_solo_multiworld(world_options, seed, _cache={})
        yield multiworld, typing.cast(StardewValleyWorld, multiworld.worlds[1])
        return

    multiworld = setup_solo_multiworld(world_options, seed)
    # The lock is released only if it was actually acquired, and always after the restoration below.
    with multiworld.lock:
        world = multiworld.worlds[1]
        original_state = multiworld.state.copy()
        original_itempool = multiworld.itempool.copy()
        unfilled_locations = multiworld.get_unfilled_locations(1)

        try:
            yield multiworld, typing.cast(StardewValleyWorld, world)
        finally:
            # Restore even when the block raised: the multiworld can be reused by later callers,
            # so it must not keep whatever the failing block did to it.
            multiworld.state = original_state
            multiworld.itempool = original_itempool
            for location in unfilled_locations:
                location.item = None
# Mostly a copy of test.general.setup_solo_multiworld, I just don't want to change the core.
def setup_solo_multiworld(test_options: Optional[Dict[Union[str, StardewValleyOption], str]] = None,
                          seed=DEFAULT_TEST_SEED,
                          _cache: Dict[frozenset, MultiWorld] = {},  # noqa
                          _steps=gen_steps) -> MultiWorld:
    """Build a single-player Stardew Valley multiworld, or return a cached one.

    Args:
        test_options: Options for the world; keys are normalised by ``parse_class_option_keys``.
        seed: Seed used to generate the multiworld; also part of the cache key.
        _cache: Cache of generated multiworlds. The default dict is intentionally shared by
            every call; pass a fresh dict to bypass the shared cache.
        _steps: Generation steps run through ``call_all``, in order.

    Returns:
        A multiworld carrying a ``lock`` attribute (``threading.Lock``). It may be a previously
        generated instance when ``should_cache_world`` allows caching and ``search_world_cache``
        finds a match for the options and seed.
    """
    test_options = parse_class_option_keys(test_options)

    # Yes I reuse the worlds generated between tests, it speeds up the execution by a couple of seconds.
    # If the simple dict caching ends up taking too much memory, we could replace it with some kind of lru cache.
    should_cache = should_cache_world(test_options)
    if should_cache:
        frozen_options = make_hashable(test_options, seed)
        cached_multi_world = search_world_cache(_cache, frozen_options)
        if cached_multi_world:
            print(f"Using cached solo multi world [Seed = {cached_multi_world.seed}] [Cache size = {len(_cache)}]")
            return cached_multi_world

    multiworld = setup_base_solo_multiworld(StardewValleyWorld, (), seed=seed)
    # print(f"Seed: {multiworld.seed}") # Uncomment to print the seed for every test

    args = fill_namespace_with_default(test_options)
    multiworld.set_options(args)

    if "start_inventory" in test_options:
        for item, amount in test_options["start_inventory"].items():
            for _ in range(amount):
                multiworld.push_precollected(multiworld.create_item(item, 1))

    for step in _steps:
        call_all(multiworld, step)

    # Lock is needed for multi-threading tests. It is attached before the multiworld is put in
    # the cache, so a cache hit can never hand out a multiworld that has no lock yet.
    setattr(multiworld, "lock", threading.Lock())

    if should_cache:
        add_to_world_cache(_cache, frozen_options, multiworld)  # noqa

    return multiworld
def should_cache_world(test_options):
    """Decide whether a world generated with ``test_options`` may be cached.

    Caching is refused when a ``start_inventory`` option is present, or when any weight of the
    ``trap_distribution`` option differs from ``options.TrapDistribution.default_weight``.
    """
    if "start_inventory" in test_options:
        return False

    if "trap_distribution" not in test_options:
        return True

    weights = test_options["trap_distribution"]
    return not any(weights[trap] != options.TrapDistribution.default_weight for trap in weights)
def make_hashable(test_options, seed):
    """Return a frozenset of the option ``(key, value)`` pairs plus a ``("seed", seed)`` pair."""
    seed_entry = ("seed", seed)
    return frozenset((*test_options.items(), seed_entry))
def search_world_cache(cache: Dict[frozenset, MultiWorld], frozen_options: frozenset) -> Optional[MultiWorld]:
    """Look up a cached multiworld compatible with ``frozen_options``.

    An exact key match wins. Otherwise the first cache entry, in iteration order, whose key
    contains every element of ``frozen_options`` is returned, or ``None`` when there is none.
    """
    if frozen_options in cache:
        return cache[frozen_options]

    superset_matches = (
        multi_world
        for cached_options, multi_world in cache.items()
        if frozen_options.issubset(cached_options)
    )
    return next(superset_matches, None)
def add_to_world_cache(cache: Dict[frozenset, MultiWorld], frozen_options: frozenset, multi_world: MultiWorld) -> None:
    """Store ``multi_world`` in ``cache`` under ``frozen_options``, replacing any previous entry."""
    # The key could be completed with all the default options, but that does not seem to improve performance.
    cache.update({frozen_options: multi_world})
def setup_multiworld(test_options: Iterable[Dict[str, int]] = None, seed=None) -> MultiWorld:  # noqa
    """Generate a multiworld with one Stardew Valley player per entry of ``test_options``.

    Args:
        test_options: One options mapping per player; ``None`` means no player. ``len()`` is
            applied to it, so it has to be a sized collection.
        seed: Passed to ``MultiWorld.set_seed``.

    Returns:
        The multiworld after every step of ``gen_steps`` has been run through ``call_all``.
    """
    if test_options is None:
        test_options = []

    player_count = len(test_options)
    multiworld = MultiWorld(player_count)
    multiworld.player_name = {}
    multiworld.set_seed(seed)

    # Players are numbered from 1.
    for player in range(1, player_count + 1):
        multiworld.game[player] = StardewValleyWorld.game
        multiworld.player_name[player] = f"Tester{player}"

    multiworld.set_options(fill_namespace_with_default(test_options))
    multiworld.state = CollectionState(multiworld)

    for generation_step in gen_steps:
        call_all(multiworld, generation_step)

    return multiworld