Locality: rewrite for linear memory consumption, from quadratic (#1091)

This commit is contained in:
Fabian Dill
2022-10-17 03:22:02 +02:00
committed by GitHub
parent bb46ee7fc1
commit b533ffb9e8
3 changed files with 63 additions and 22 deletions

View File

@@ -1,3 +1,4 @@
import collections
import typing
from BaseClasses import LocationProgressType, MultiWorld
@@ -12,29 +13,72 @@ else:
ItemRule = typing.Callable[[object], bool]
def group_locality_rules(world):
def locality_needed(world: MultiWorld) -> bool:
for player in world.player_ids:
if world.local_items[player].value:
return True
if world.non_local_items[player].value:
return True
# Group
for group_id, group in world.groups.items():
if set(world.player_ids) == set(group["players"]):
continue
if group["local_items"]:
for location in world.get_locations():
if location.player not in group["players"]:
forbid_items_for_player(location, group["local_items"], group_id)
return True
if group["non_local_items"]:
for location in world.get_locations():
if location.player in group["players"]:
forbid_items_for_player(location, group["non_local_items"], group_id)
return True
def locality_rules(world, player: int):
if world.local_items[player].value:
def locality_rules(world: MultiWorld):
if locality_needed(world):
forbid_data: typing.Dict[int, typing.Dict[int, typing.Set[str]]] = \
collections.defaultdict(lambda: collections.defaultdict(set))
def forbid(sender: int, receiver: int, items: typing.Set[str]):
forbid_data[sender][receiver].update(items)
for receiving_player in world.player_ids:
local_items: typing.Set[str] = world.local_items[receiving_player].value
if local_items:
for sending_player in world.player_ids:
if receiving_player != sending_player:
forbid(sending_player, receiving_player, local_items)
non_local_items: typing.Set[str] = world.non_local_items[receiving_player].value
if non_local_items:
forbid(receiving_player, receiving_player, non_local_items)
# Group
for receiving_group_id, receiving_group in world.groups.items():
if set(world.player_ids) == set(receiving_group["players"]):
continue
if receiving_group["local_items"]:
for sending_player in world.player_ids:
if sending_player not in receiving_group["players"]:
forbid(sending_player, receiving_group_id, receiving_group["local_items"])
if receiving_group["non_local_items"]:
for sending_player in world.player_ids:
if sending_player in receiving_group["players"]:
forbid(sending_player, receiving_group_id, receiving_group["non_local_items"])
# create fewer lambda's to save memory and cache misses
func_cache = {}
for location in world.get_locations():
if location.player != player:
forbid_items_for_player(location, world.local_items[player].value, player)
if world.non_local_items[player].value:
for location in world.get_locations():
if location.player == player:
forbid_items_for_player(location, world.non_local_items[player].value, player)
if (location.player, location.item_rule) in func_cache:
location.item_rule = func_cache[location.player, location.item_rule]
# empty rule that just returns True, overwrite
elif location.item_rule is location.__class__.item_rule:
func_cache[location.player, location.item_rule] = location.item_rule = \
lambda i, sending_blockers = forbid_data[location.player], \
old_rule = location.item_rule: \
i.name not in sending_blockers[i.player]
# special rule, needs to also be fulfilled.
else:
func_cache[location.player, location.item_rule] = location.item_rule = \
lambda i, sending_blockers = forbid_data[location.player], \
old_rule = location.item_rule: \
i.name not in sending_blockers[i.player] and old_rule(i)
def exclusion_rules(world: MultiWorld, player: int, exclude_locations: typing.Set[str]) -> None: