Core: some typing and cleaning in BaseClasses.py (#3391)

* Core: some typing and cleaning in `BaseClasses.py`

* more backwards `__repr__`

* double-quote string

* remove some end-of-line whitespace
This commit is contained in:
Doug Hoskisson
2024-08-23 17:05:30 -07:00
committed by GitHub
parent 64b654d42e
commit 43cb9611fb

View File

@@ -11,8 +11,10 @@ from argparse import Namespace
from collections import Counter, deque from collections import Counter, deque
from collections.abc import Collection, MutableSequence from collections.abc import Collection, MutableSequence
from enum import IntEnum, IntFlag from enum import IntEnum, IntFlag
from typing import Any, Callable, Dict, Iterable, Iterator, List, Mapping, NamedTuple, Optional, Set, Tuple, \ from typing import (AbstractSet, Any, Callable, ClassVar, Dict, Iterable, Iterator, List, Mapping, NamedTuple,
TypedDict, Union, Type, ClassVar Optional, Protocol, Set, Tuple, Union, Type)
from typing_extensions import NotRequired, TypedDict
import NetUtils import NetUtils
import Options import Options
@@ -22,16 +24,16 @@ if typing.TYPE_CHECKING:
from worlds import AutoWorld from worlds import AutoWorld
class Group(TypedDict, total=False): class Group(TypedDict):
name: str name: str
game: str game: str
world: "AutoWorld.World" world: "AutoWorld.World"
players: Set[int] players: AbstractSet[int]
item_pool: Set[str] item_pool: NotRequired[Set[str]]
replacement_items: Dict[int, Optional[str]] replacement_items: NotRequired[Dict[int, Optional[str]]]
local_items: Set[str] local_items: NotRequired[Set[str]]
non_local_items: Set[str] non_local_items: NotRequired[Set[str]]
link_replacement: bool link_replacement: NotRequired[bool]
class ThreadBarrierProxy: class ThreadBarrierProxy:
@@ -48,6 +50,11 @@ class ThreadBarrierProxy:
"Please use multiworld.per_slot_randoms[player] or randomize ahead of output.") "Please use multiworld.per_slot_randoms[player] or randomize ahead of output.")
class HasNameAndPlayer(Protocol):
name: str
player: int
class MultiWorld(): class MultiWorld():
debug_types = False debug_types = False
player_name: Dict[int, str] player_name: Dict[int, str]
@@ -156,7 +163,7 @@ class MultiWorld():
self.start_inventory_from_pool: Dict[int, Options.StartInventoryPool] = {} self.start_inventory_from_pool: Dict[int, Options.StartInventoryPool] = {}
for player in range(1, players + 1): for player in range(1, players + 1):
def set_player_attr(attr, val): def set_player_attr(attr: str, val) -> None:
self.__dict__.setdefault(attr, {})[player] = val self.__dict__.setdefault(attr, {})[player] = val
set_player_attr('plando_items', []) set_player_attr('plando_items', [])
set_player_attr('plando_texts', {}) set_player_attr('plando_texts', {})
@@ -165,13 +172,13 @@ class MultiWorld():
set_player_attr('completion_condition', lambda state: True) set_player_attr('completion_condition', lambda state: True)
self.worlds = {} self.worlds = {}
self.per_slot_randoms = Utils.DeprecateDict("Using per_slot_randoms is now deprecated. Please use the " self.per_slot_randoms = Utils.DeprecateDict("Using per_slot_randoms is now deprecated. Please use the "
"world's random object instead (usually self.random)") "world's random object instead (usually self.random)")
self.plando_options = PlandoOptions.none self.plando_options = PlandoOptions.none
def get_all_ids(self) -> Tuple[int, ...]: def get_all_ids(self) -> Tuple[int, ...]:
return self.player_ids + tuple(self.groups) return self.player_ids + tuple(self.groups)
def add_group(self, name: str, game: str, players: Set[int] = frozenset()) -> Tuple[int, Group]: def add_group(self, name: str, game: str, players: AbstractSet[int] = frozenset()) -> Tuple[int, Group]:
"""Create a group with name and return the assigned player ID and group. """Create a group with name and return the assigned player ID and group.
If a group of this name already exists, the set of players is extended instead of creating a new one.""" If a group of this name already exists, the set of players is extended instead of creating a new one."""
from worlds import AutoWorld from worlds import AutoWorld
@@ -195,7 +202,7 @@ class MultiWorld():
return new_id, new_group return new_id, new_group
def get_player_groups(self, player) -> Set[int]: def get_player_groups(self, player: int) -> Set[int]:
return {group_id for group_id, group in self.groups.items() if player in group["players"]} return {group_id for group_id, group in self.groups.items() if player in group["players"]}
def set_seed(self, seed: Optional[int] = None, secure: bool = False, name: Optional[str] = None): def set_seed(self, seed: Optional[int] = None, secure: bool = False, name: Optional[str] = None):
@@ -258,7 +265,7 @@ class MultiWorld():
"link_replacement": replacement_prio.index(item_link["link_replacement"]), "link_replacement": replacement_prio.index(item_link["link_replacement"]),
} }
for name, item_link in item_links.items(): for _name, item_link in item_links.items():
current_item_name_groups = AutoWorld.AutoWorldRegister.world_types[item_link["game"]].item_name_groups current_item_name_groups = AutoWorld.AutoWorldRegister.world_types[item_link["game"]].item_name_groups
pool = set() pool = set()
local_items = set() local_items = set()
@@ -388,7 +395,7 @@ class MultiWorld():
return tuple(world for player, world in self.worlds.items() if return tuple(world for player, world in self.worlds.items() if
player not in self.groups and self.game[player] == game_name) player not in self.groups and self.game[player] == game_name)
def get_name_string_for_object(self, obj) -> str: def get_name_string_for_object(self, obj: HasNameAndPlayer) -> str:
return obj.name if self.players == 1 else f'{obj.name} ({self.get_player_name(obj.player)})' return obj.name if self.players == 1 else f'{obj.name} ({self.get_player_name(obj.player)})'
def get_player_name(self, player: int) -> str: def get_player_name(self, player: int) -> str:
@@ -439,7 +446,7 @@ class MultiWorld():
def get_items(self) -> List[Item]: def get_items(self) -> List[Item]:
return [loc.item for loc in self.get_filled_locations()] + self.itempool return [loc.item for loc in self.get_filled_locations()] + self.itempool
def find_item_locations(self, item, player: int, resolve_group_locations: bool = False) -> List[Location]: def find_item_locations(self, item: str, player: int, resolve_group_locations: bool = False) -> List[Location]:
if resolve_group_locations: if resolve_group_locations:
player_groups = self.get_player_groups(player) player_groups = self.get_player_groups(player)
return [location for location in self.get_locations() if return [location for location in self.get_locations() if
@@ -448,7 +455,7 @@ class MultiWorld():
return [location for location in self.get_locations() if return [location for location in self.get_locations() if
location.item and location.item.name == item and location.item.player == player] location.item and location.item.name == item and location.item.player == player]
def find_item(self, item, player: int) -> Location: def find_item(self, item: str, player: int) -> Location:
return next(location for location in self.get_locations() if return next(location for location in self.get_locations() if
location.item and location.item.name == item and location.item.player == player) location.item and location.item.name == item and location.item.player == player)
@@ -900,7 +907,7 @@ class Entrance:
addresses = None addresses = None
target = None target = None
def __init__(self, player: int, name: str = '', parent: Region = None): def __init__(self, player: int, name: str = "", parent: Optional[Region] = None) -> None:
self.name = name self.name = name
self.parent_region = parent self.parent_region = parent
self.player = player self.player = player
@@ -920,9 +927,6 @@ class Entrance:
region.entrances.append(self) region.entrances.append(self)
def __repr__(self): def __repr__(self):
return self.__str__()
def __str__(self):
multiworld = self.parent_region.multiworld if self.parent_region else None multiworld = self.parent_region.multiworld if self.parent_region else None
return multiworld.get_name_string_for_object(self) if multiworld else f'{self.name} (Player {self.player})' return multiworld.get_name_string_for_object(self) if multiworld else f'{self.name} (Player {self.player})'
@@ -1048,7 +1052,7 @@ class Region:
self.locations.append(location_type(self.player, location, address, self)) self.locations.append(location_type(self.player, location, address, self))
def connect(self, connecting_region: Region, name: Optional[str] = None, def connect(self, connecting_region: Region, name: Optional[str] = None,
rule: Optional[Callable[[CollectionState], bool]] = None) -> entrance_type: rule: Optional[Callable[[CollectionState], bool]] = None) -> Entrance:
""" """
Connects this Region to another Region, placing the provided rule on the connection. Connects this Region to another Region, placing the provided rule on the connection.
@@ -1088,9 +1092,6 @@ class Region:
rules[connecting_region] if rules and connecting_region in rules else None) rules[connecting_region] if rules and connecting_region in rules else None)
def __repr__(self): def __repr__(self):
return self.__str__()
def __str__(self):
return self.multiworld.get_name_string_for_object(self) if self.multiworld else f'{self.name} (Player {self.player})' return self.multiworld.get_name_string_for_object(self) if self.multiworld else f'{self.name} (Player {self.player})'
@@ -1109,9 +1110,9 @@ class Location:
locked: bool = False locked: bool = False
show_in_spoiler: bool = True show_in_spoiler: bool = True
progress_type: LocationProgressType = LocationProgressType.DEFAULT progress_type: LocationProgressType = LocationProgressType.DEFAULT
always_allow = staticmethod(lambda state, item: False) always_allow: Callable[[CollectionState, Item], bool] = staticmethod(lambda state, item: False)
access_rule: Callable[[CollectionState], bool] = staticmethod(lambda state: True) access_rule: Callable[[CollectionState], bool] = staticmethod(lambda state: True)
item_rule = staticmethod(lambda item: True) item_rule: Callable[[Item], bool] = staticmethod(lambda item: True)
item: Optional[Item] = None item: Optional[Item] = None
def __init__(self, player: int, name: str = '', address: Optional[int] = None, parent: Optional[Region] = None): def __init__(self, player: int, name: str = '', address: Optional[int] = None, parent: Optional[Region] = None):
@@ -1120,11 +1121,15 @@ class Location:
self.address = address self.address = address
self.parent_region = parent self.parent_region = parent
def can_fill(self, state: CollectionState, item: Item, check_access=True) -> bool: def can_fill(self, state: CollectionState, item: Item, check_access: bool = True) -> bool:
return ((self.always_allow(state, item) and item.name not in state.multiworld.worlds[item.player].options.non_local_items) return ((
or ((self.progress_type != LocationProgressType.EXCLUDED or not (item.advancement or item.useful)) self.always_allow(state, item)
and self.item_rule(item) and item.name not in state.multiworld.worlds[item.player].options.non_local_items
and (not check_access or self.can_reach(state)))) ) or (
(self.progress_type != LocationProgressType.EXCLUDED or not (item.advancement or item.useful))
and self.item_rule(item)
and (not check_access or self.can_reach(state))
))
def can_reach(self, state: CollectionState) -> bool: def can_reach(self, state: CollectionState) -> bool:
# Region.can_reach is just a cache lookup, so placing it first for faster abort on average # Region.can_reach is just a cache lookup, so placing it first for faster abort on average
@@ -1139,9 +1144,6 @@ class Location:
self.locked = True self.locked = True
def __repr__(self): def __repr__(self):
return self.__str__()
def __str__(self):
multiworld = self.parent_region.multiworld if self.parent_region and self.parent_region.multiworld else None multiworld = self.parent_region.multiworld if self.parent_region and self.parent_region.multiworld else None
return multiworld.get_name_string_for_object(self) if multiworld else f'{self.name} (Player {self.player})' return multiworld.get_name_string_for_object(self) if multiworld else f'{self.name} (Player {self.player})'
@@ -1163,7 +1165,7 @@ class Location:
@property @property
def native_item(self) -> bool: def native_item(self) -> bool:
"""Returns True if the item in this location matches game.""" """Returns True if the item in this location matches game."""
return self.item and self.item.game == self.game return self.item is not None and self.item.game == self.game
@property @property
def hint_text(self) -> str: def hint_text(self) -> str:
@@ -1246,9 +1248,6 @@ class Item:
return hash((self.name, self.player)) return hash((self.name, self.player))
def __repr__(self) -> str: def __repr__(self) -> str:
return self.__str__()
def __str__(self) -> str:
if self.location and self.location.parent_region and self.location.parent_region.multiworld: if self.location and self.location.parent_region and self.location.parent_region.multiworld:
return self.location.parent_region.multiworld.get_name_string_for_object(self) return self.location.parent_region.multiworld.get_name_string_for_object(self)
return f"{self.name} (Player {self.player})" return f"{self.name} (Player {self.player})"
@@ -1326,9 +1325,9 @@ class Spoiler:
# in the second phase, we cull each sphere such that the game is still beatable, # in the second phase, we cull each sphere such that the game is still beatable,
# reducing each range of influence to the bare minimum required inside it # reducing each range of influence to the bare minimum required inside it
restore_later = {} restore_later: Dict[Location, Item] = {}
for num, sphere in reversed(tuple(enumerate(collection_spheres))): for num, sphere in reversed(tuple(enumerate(collection_spheres))):
to_delete = set() to_delete: Set[Location] = set()
for location in sphere: for location in sphere:
# we remove the item at location and check if game is still beatable # we remove the item at location and check if game is still beatable
logging.debug('Checking if %s (Player %d) is required to beat the game.', location.item.name, logging.debug('Checking if %s (Player %d) is required to beat the game.', location.item.name,
@@ -1346,7 +1345,7 @@ class Spoiler:
sphere -= to_delete sphere -= to_delete
# second phase, sphere 0 # second phase, sphere 0
removed_precollected = [] removed_precollected: List[Item] = []
for item in (i for i in chain.from_iterable(multiworld.precollected_items.values()) if i.advancement): for item in (i for i in chain.from_iterable(multiworld.precollected_items.values()) if i.advancement):
logging.debug('Checking if %s (Player %d) is required to beat the game.', item.name, item.player) logging.debug('Checking if %s (Player %d) is required to beat the game.', item.name, item.player)
multiworld.precollected_items[item.player].remove(item) multiworld.precollected_items[item.player].remove(item)
@@ -1499,9 +1498,9 @@ class Spoiler:
if self.paths: if self.paths:
outfile.write('\n\nPaths:\n\n') outfile.write('\n\nPaths:\n\n')
path_listings = [] path_listings: List[str] = []
for location, path in sorted(self.paths.items()): for location, path in sorted(self.paths.items()):
path_lines = [] path_lines: List[str] = []
for region, exit in path: for region, exit in path:
if exit is not None: if exit is not None:
path_lines.append("{} -> {}".format(region, exit)) path_lines.append("{} -> {}".format(region, exit))