mirror of
https://github.com/MarioSpore/Grinch-AP.git
synced 2025-10-21 12:11:33 -06:00
307 lines
11 KiB
Python
307 lines
11 KiB
Python
import itertools
|
|
import logging
|
|
import os
|
|
import threading
|
|
import typing
|
|
import unittest
|
|
from contextlib import contextmanager
|
|
from typing import Optional, Dict, Union, Any, List, Iterable
|
|
|
|
from BaseClasses import get_seed, MultiWorld, Location, Item, CollectionState
|
|
from test.bases import WorldTestBase
|
|
from test.general import gen_steps, setup_solo_multiworld as setup_base_solo_multiworld
|
|
from worlds.AutoWorld import call_all
|
|
from .assertion import RuleAssertMixin
|
|
from .options.utils import parse_class_option_keys, fill_namespace_with_default
|
|
from .. import StardewValleyWorld, StardewItem, StardewRule
|
|
from ..logic.time_logic import MONTH_COEFFICIENT
|
|
from ..options import StardewValleyOption, options
|
|
|
|
logger = logging.getLogger(__name__)
|
|
DEFAULT_TEST_SEED = get_seed()
|
|
logger.info(f"Default Test Seed: {DEFAULT_TEST_SEED}")
|
|
|
|
|
|
def skip_default_tests() -> bool:
|
|
return not bool(os.environ.get("base", False))
|
|
|
|
|
|
def skip_long_tests() -> bool:
|
|
return not bool(os.environ.get("long", False))
|
|
|
|
|
|
class SVTestCase(unittest.TestCase):
|
|
skip_default_tests: bool = skip_default_tests()
|
|
"""Set False to not skip the base fill tests"""
|
|
skip_long_tests: bool = skip_long_tests()
|
|
"""Set False to run tests that take long"""
|
|
|
|
@contextmanager
|
|
def solo_world_sub_test(self, msg: str | None = None,
|
|
/,
|
|
world_options: dict[str | type[StardewValleyOption], Any] | None = None,
|
|
*,
|
|
seed=DEFAULT_TEST_SEED,
|
|
world_caching=True,
|
|
**kwargs) -> Iterable[tuple[MultiWorld, StardewValleyWorld]]:
|
|
if msg is not None:
|
|
msg += " "
|
|
else:
|
|
msg = ""
|
|
msg += f"[Seed = {seed}]"
|
|
|
|
with self.subTest(msg, **kwargs):
|
|
with solo_multiworld(world_options, seed=seed, world_caching=world_caching) as (multiworld, world):
|
|
yield multiworld, world
|
|
|
|
|
|
class SVTestBase(RuleAssertMixin, WorldTestBase, SVTestCase):
|
|
game = "Stardew Valley"
|
|
world: StardewValleyWorld
|
|
|
|
seed = DEFAULT_TEST_SEED
|
|
|
|
@classmethod
|
|
def setUpClass(cls) -> None:
|
|
if cls is SVTestBase:
|
|
raise unittest.SkipTest("No running tests on SVTestBase import.")
|
|
|
|
super().setUpClass()
|
|
|
|
def world_setup(self, *args, **kwargs):
|
|
self.options = parse_class_option_keys(self.options)
|
|
|
|
self.multiworld = setup_solo_multiworld(self.options, seed=self.seed)
|
|
self.multiworld.lock.acquire()
|
|
world = self.multiworld.worlds[self.player]
|
|
|
|
self.original_state = self.multiworld.state.copy()
|
|
self.original_itempool = self.multiworld.itempool.copy()
|
|
self.unfilled_locations = self.multiworld.get_unfilled_locations(1)
|
|
if self.constructed:
|
|
self.world = world # noqa
|
|
|
|
def tearDown(self) -> None:
|
|
self.multiworld.state = self.original_state
|
|
self.multiworld.itempool = self.original_itempool
|
|
for location in self.unfilled_locations:
|
|
location.item = None
|
|
|
|
self.multiworld.lock.release()
|
|
|
|
@property
|
|
def run_default_tests(self) -> bool:
|
|
if self.skip_default_tests:
|
|
return False
|
|
return super().run_default_tests
|
|
|
|
def collect_months(self, months: int) -> None:
|
|
real_total_prog_items = self.world.total_progression_items
|
|
percent = months * MONTH_COEFFICIENT
|
|
self.collect("Stardrop", real_total_prog_items * 100 // percent)
|
|
self.world.total_progression_items = real_total_prog_items
|
|
|
|
def collect_lots_of_money(self, percent: float = 0.25):
|
|
self.collect("Shipping Bin")
|
|
real_total_prog_items = self.world.total_progression_items
|
|
required_prog_items = int(round(real_total_prog_items * percent))
|
|
self.collect("Stardrop", required_prog_items)
|
|
|
|
def collect_all_the_money(self):
|
|
self.collect_lots_of_money(0.95)
|
|
|
|
def collect_everything(self):
|
|
non_event_items = [item for item in self.multiworld.get_items() if item.code]
|
|
for item in non_event_items:
|
|
self.multiworld.state.collect(item)
|
|
|
|
def collect_all_except(self, item_to_not_collect: str):
|
|
non_event_items = [item for item in self.multiworld.get_items() if item.code]
|
|
for item in non_event_items:
|
|
if item.name != item_to_not_collect:
|
|
self.multiworld.state.collect(item)
|
|
|
|
def get_real_locations(self) -> List[Location]:
|
|
return [location for location in self.multiworld.get_locations(self.player) if location.address is not None]
|
|
|
|
def get_real_location_names(self) -> List[str]:
|
|
return [location.name for location in self.get_real_locations()]
|
|
|
|
def collect(self, item: Union[str, Item, Iterable[Item]], count: int = 1) -> Union[None, Item, List[Item]]:
|
|
assert count > 0
|
|
|
|
if not isinstance(item, str):
|
|
super().collect(item)
|
|
return
|
|
|
|
if count == 1:
|
|
item = self.create_item(item)
|
|
self.multiworld.state.collect(item)
|
|
return item
|
|
|
|
items = []
|
|
for i in range(count):
|
|
item = self.create_item(item)
|
|
self.multiworld.state.collect(item)
|
|
items.append(item)
|
|
|
|
return items
|
|
|
|
def create_item(self, item: str) -> StardewItem:
|
|
return self.world.create_item(item)
|
|
|
|
def get_all_created_items(self) -> list[str]:
|
|
return [item.name for item in itertools.chain(self.multiworld.get_items(), self.multiworld.precollected_items[self.player])]
|
|
|
|
def remove_one_by_name(self, item: str) -> None:
|
|
self.remove(self.create_item(item))
|
|
|
|
def reset_collection_state(self) -> None:
|
|
self.multiworld.state = self.original_state.copy()
|
|
|
|
def assert_rule_true(self, rule: StardewRule, state: CollectionState | None = None) -> None:
|
|
if state is None:
|
|
state = self.multiworld.state
|
|
super().assert_rule_true(rule, state)
|
|
|
|
def assert_rule_false(self, rule: StardewRule, state: CollectionState | None = None) -> None:
|
|
if state is None:
|
|
state = self.multiworld.state
|
|
super().assert_rule_false(rule, state)
|
|
|
|
def assert_can_reach_location(self, location: Location | str, state: CollectionState | None = None) -> None:
|
|
if state is None:
|
|
state = self.multiworld.state
|
|
super().assert_can_reach_location(location, state)
|
|
|
|
def assert_cannot_reach_location(self, location: Location | str, state: CollectionState | None = None) -> None:
|
|
if state is None:
|
|
state = self.multiworld.state
|
|
super().assert_cannot_reach_location(location, state)
|
|
|
|
|
|
pre_generated_worlds = {}
|
|
|
|
|
|
@contextmanager
|
|
def solo_multiworld(world_options: dict[str | type[StardewValleyOption], Any] | None = None,
|
|
*,
|
|
seed=DEFAULT_TEST_SEED,
|
|
world_caching=True) -> Iterable[tuple[MultiWorld, StardewValleyWorld]]:
|
|
if not world_caching:
|
|
multiworld = setup_solo_multiworld(world_options, seed, _cache={})
|
|
yield multiworld, typing.cast(StardewValleyWorld, multiworld.worlds[1])
|
|
else:
|
|
multiworld = setup_solo_multiworld(world_options, seed)
|
|
try:
|
|
multiworld.lock.acquire()
|
|
world = multiworld.worlds[1]
|
|
|
|
original_state = multiworld.state.copy()
|
|
original_itempool = multiworld.itempool.copy()
|
|
unfilled_locations = multiworld.get_unfilled_locations(1)
|
|
|
|
yield multiworld, typing.cast(StardewValleyWorld, world)
|
|
|
|
multiworld.state = original_state
|
|
multiworld.itempool = original_itempool
|
|
for location in unfilled_locations:
|
|
location.item = None
|
|
finally:
|
|
multiworld.lock.release()
|
|
|
|
|
|
# Mostly a copy of test.general.setup_solo_multiworld, I just don't want to change the core.
|
|
def setup_solo_multiworld(test_options: Optional[Dict[Union[str, StardewValleyOption], str]] = None,
|
|
seed=DEFAULT_TEST_SEED,
|
|
_cache: Dict[frozenset, MultiWorld] = {}, # noqa
|
|
_steps=gen_steps) -> MultiWorld:
|
|
test_options = parse_class_option_keys(test_options)
|
|
|
|
# Yes I reuse the worlds generated between tests, its speeds the execution by a couple seconds
|
|
# If the simple dict caching ends up taking too much memory, we could replace it with some kind of lru cache.
|
|
should_cache = should_cache_world(test_options)
|
|
if should_cache:
|
|
frozen_options = make_hashable(test_options, seed)
|
|
cached_multi_world = search_world_cache(_cache, frozen_options)
|
|
if cached_multi_world:
|
|
print(f"Using cached solo multi world [Seed = {cached_multi_world.seed}] [Cache size = {len(_cache)}]")
|
|
return cached_multi_world
|
|
|
|
multiworld = setup_base_solo_multiworld(StardewValleyWorld, (), seed=seed)
|
|
# print(f"Seed: {multiworld.seed}") # Uncomment to print the seed for every test
|
|
|
|
args = fill_namespace_with_default(test_options)
|
|
multiworld.set_options(args)
|
|
|
|
if "start_inventory" in test_options:
|
|
for item, amount in test_options["start_inventory"].items():
|
|
for _ in range(amount):
|
|
multiworld.push_precollected(multiworld.create_item(item, 1))
|
|
|
|
for step in _steps:
|
|
call_all(multiworld, step)
|
|
|
|
if should_cache:
|
|
add_to_world_cache(_cache, frozen_options, multiworld) # noqa
|
|
|
|
# Lock is needed for multi-threading tests
|
|
setattr(multiworld, "lock", threading.Lock())
|
|
|
|
return multiworld
|
|
|
|
|
|
def should_cache_world(test_options):
|
|
if "start_inventory" in test_options:
|
|
return False
|
|
|
|
trap_distribution_key = "trap_distribution"
|
|
if trap_distribution_key not in test_options:
|
|
return True
|
|
|
|
trap_distribution = test_options[trap_distribution_key]
|
|
for key in trap_distribution:
|
|
if trap_distribution[key] != options.TrapDistribution.default_weight:
|
|
return False
|
|
|
|
return True
|
|
|
|
|
|
def make_hashable(test_options, seed):
|
|
return frozenset(test_options.items()).union({("seed", seed)})
|
|
|
|
|
|
def search_world_cache(cache: Dict[frozenset, MultiWorld], frozen_options: frozenset) -> Optional[MultiWorld]:
|
|
try:
|
|
return cache[frozen_options]
|
|
except KeyError:
|
|
for cached_options, multi_world in cache.items():
|
|
if frozen_options.issubset(cached_options):
|
|
return multi_world
|
|
return None
|
|
|
|
|
|
def add_to_world_cache(cache: Dict[frozenset, MultiWorld], frozen_options: frozenset, multi_world: MultiWorld) -> None:
|
|
# We could complete the key with all the default options, but that does not seem to improve performances.
|
|
cache[frozen_options] = multi_world
|
|
|
|
|
|
def setup_multiworld(test_options: Iterable[Dict[str, int]] = None, seed=None) -> MultiWorld: # noqa
|
|
if test_options is None:
|
|
test_options = []
|
|
|
|
multiworld = MultiWorld(len(test_options))
|
|
multiworld.player_name = {}
|
|
multiworld.set_seed(seed)
|
|
for i in range(1, len(test_options) + 1):
|
|
multiworld.game[i] = StardewValleyWorld.game
|
|
multiworld.player_name.update({i: f"Tester{i}"})
|
|
args = fill_namespace_with_default(test_options)
|
|
multiworld.set_options(args)
|
|
multiworld.state = CollectionState(multiworld)
|
|
|
|
for step in gen_steps:
|
|
call_all(multiworld, step)
|
|
|
|
return multiworld
|