diff --git a/BaseClasses.py b/BaseClasses.py index a60951b2..e5d92b95 100644 --- a/BaseClasses.py +++ b/BaseClasses.py @@ -7,6 +7,7 @@ import json import functools from collections import OrderedDict, Counter, deque from typing import List, Dict, Optional, Set, Iterable, Union, Any, Tuple, TypedDict, Callable +import typing # this can go away when Python 3.8 support is dropped import secrets import random @@ -563,9 +564,20 @@ class MultiWorld(): return False +PathValue = Tuple[str, Optional["PathValue"]] + + class CollectionState(): - additional_init_functions: List[Callable] = [] - additional_copy_functions: List[Callable] = [] + prog_items: typing.Counter[Tuple[str, int]] + world: MultiWorld + reachable_regions: Dict[int, Set[Region]] + blocked_connections: Dict[int, Set[Entrance]] + events: Set[Location] + path: Dict[Union[Region, Entrance], PathValue] + locations_checked: Set[Location] + stale: Dict[int, bool] + additional_init_functions: List[Callable[[CollectionState, MultiWorld], None]] = [] + additional_copy_functions: List[Callable[[CollectionState, CollectionState], CollectionState]] = [] def __init__(self, parent: MultiWorld): self.prog_items = Counter() @@ -603,6 +615,7 @@ class CollectionState(): if new_region in rrp: bc.remove(connection) elif connection.can_reach(self): + assert new_region, "tried to search through an Entrance with no Region" rrp.add(new_region) bc.remove(connection) bc.update(new_region.exits) @@ -633,7 +646,8 @@ class CollectionState(): spot: Union[Location, Entrance, Region, str], resolution_hint: Optional[str] = None, player: Optional[int] = None) -> bool: - if not hasattr(spot, "can_reach"): + if isinstance(spot, str): + assert isinstance(player, int), "can_reach: player is required if spot is str" # try to resolve a name if resolution_hint == 'Location': spot = self.world.get_location(spot, player) @@ -644,7 +658,7 @@ class CollectionState(): spot = self.world.get_region(spot, player) return spot.can_reach(self) - def sweep_for_events(self, key_only: bool = False, locations: Set[Location] = None): + def sweep_for_events(self, key_only: bool = False, locations: Optional[Iterable[Location]] = None) -> None: if locations is None: locations = self.world.get_filled_locations() new_locations = True @@ -656,6 +670,7 @@ class CollectionState(): new_locations = reachable_events - self.events for event in new_locations: self.events.add(event) + assert isinstance(event.item, Item), "tried to collect Event with no Item" self.collect(event.item, True, event) def has(self, item: str, player: int, count: int = 1) -> bool: @@ -670,7 +685,7 @@ class CollectionState(): def count(self, item: str, player: int) -> int: return self.prog_items[item, player] - def has_group(self, item_name_group: str, player: int, count: int = 1): + def has_group(self, item_name_group: str, player: int, count: int = 1) -> bool: found: int = 0 for item_name in self.world.worlds[player].item_name_groups[item_name_group]: found += self.prog_items[item_name, player] @@ -678,7 +693,7 @@ class CollectionState(): return True return False - def count_group(self, item_name_group: str, player: int): + def count_group(self, item_name_group: str, player: int) -> int: found: int = 0 for item_name in self.world.worlds[player].item_name_groups[item_name_group]: found += self.prog_items[item_name, player]