158
worlds/_sc2common/bot/expiring_dict.py
Normal file
158
worlds/_sc2common/bot/expiring_dict.py
Normal file
@@ -0,0 +1,158 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import OrderedDict
|
||||
from threading import RLock
|
||||
from typing import TYPE_CHECKING, Any, Iterable, Union
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .bot_ai import BotAI
|
||||
|
||||
|
||||
class ExpiringDict(OrderedDict):
|
||||
"""
|
||||
An expiring dict that uses the bot.state.game_loop to only return items that are valid for a specific amount of time.
|
||||
|
||||
Example usages::
|
||||
|
||||
async def on_step(iteration: int):
|
||||
# This dict will hold up to 10 items and only return values that have been added up to 20 frames ago
|
||||
my_dict = ExpiringDict(self, max_age_frames=20)
|
||||
if iteration == 0:
|
||||
# Add item
|
||||
my_dict["test"] = "something"
|
||||
if iteration == 2:
|
||||
# On default, one iteration is called every 8 frames
|
||||
if "test" in my_dict:
|
||||
print("test is in dict")
|
||||
if iteration == 20:
|
||||
if "test" not in my_dict:
|
||||
print("test is not anymore in dict")
|
||||
"""
|
||||
|
||||
def __init__(self, bot: BotAI, max_age_frames: int = 1):
|
||||
assert max_age_frames >= -1
|
||||
assert bot
|
||||
|
||||
OrderedDict.__init__(self)
|
||||
self.bot: BotAI = bot
|
||||
self.max_age: Union[int, float] = max_age_frames
|
||||
self.lock: RLock = RLock()
|
||||
|
||||
@property
|
||||
def frame(self) -> int:
|
||||
return self.bot.state.game_loop
|
||||
|
||||
def __contains__(self, key) -> bool:
|
||||
""" Return True if dict has key, else False, e.g. 'key in dict' """
|
||||
with self.lock:
|
||||
if OrderedDict.__contains__(self, key):
|
||||
# Each item is a list of [value, frame time]
|
||||
item = OrderedDict.__getitem__(self, key)
|
||||
if self.frame - item[1] < self.max_age:
|
||||
return True
|
||||
del self[key]
|
||||
return False
|
||||
|
||||
def __getitem__(self, key, with_age=False) -> Any:
|
||||
""" Return the item of the dict using d[key] """
|
||||
with self.lock:
|
||||
# Each item is a list of [value, frame time]
|
||||
item = OrderedDict.__getitem__(self, key)
|
||||
if self.frame - item[1] < self.max_age:
|
||||
if with_age:
|
||||
return item[0], item[1]
|
||||
return item[0]
|
||||
OrderedDict.__delitem__(self, key)
|
||||
raise KeyError(key)
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
""" Set d[key] = value """
|
||||
with self.lock:
|
||||
OrderedDict.__setitem__(self, key, (value, self.frame))
|
||||
|
||||
def __repr__(self):
|
||||
""" Printable version of the dict instead of getting memory adress """
|
||||
print_list = []
|
||||
with self.lock:
|
||||
for key, value in OrderedDict.items(self):
|
||||
if self.frame - value[1] < self.max_age:
|
||||
print_list.append(f"{repr(key)}: {repr(value)}")
|
||||
print_str = ", ".join(print_list)
|
||||
return f"ExpiringDict({print_str})"
|
||||
|
||||
def __str__(self):
|
||||
return self.__repr__()
|
||||
|
||||
def __iter__(self):
|
||||
""" Override 'for key in dict:' """
|
||||
with self.lock:
|
||||
return self.keys()
|
||||
|
||||
# TODO find a way to improve len
|
||||
def __len__(self):
|
||||
"""Override len method as key value pairs aren't instantly being deleted, but only on __get__(item).
|
||||
This function is slow because it has to check if each element is not expired yet."""
|
||||
with self.lock:
|
||||
count = 0
|
||||
for _ in self.values():
|
||||
count += 1
|
||||
return count
|
||||
|
||||
def pop(self, key, default=None, with_age=False):
|
||||
""" Return the item and remove it """
|
||||
with self.lock:
|
||||
if OrderedDict.__contains__(self, key):
|
||||
item = OrderedDict.__getitem__(self, key)
|
||||
if self.frame - item[1] < self.max_age:
|
||||
del self[key]
|
||||
if with_age:
|
||||
return item[0], item[1]
|
||||
return item[0]
|
||||
del self[key]
|
||||
if default is None:
|
||||
raise KeyError(key)
|
||||
if with_age:
|
||||
return default, self.frame
|
||||
return default
|
||||
|
||||
def get(self, key, default=None, with_age=False):
|
||||
""" Return the value for key if key is in dict, else default """
|
||||
with self.lock:
|
||||
if OrderedDict.__contains__(self, key):
|
||||
item = OrderedDict.__getitem__(self, key)
|
||||
if self.frame - item[1] < self.max_age:
|
||||
if with_age:
|
||||
return item[0], item[1]
|
||||
return item[0]
|
||||
if default is None:
|
||||
raise KeyError(key)
|
||||
if with_age:
|
||||
return default, self.frame
|
||||
return None
|
||||
return None
|
||||
|
||||
def update(self, other_dict: dict):
|
||||
with self.lock:
|
||||
for key, value in other_dict.items():
|
||||
self[key] = value
|
||||
|
||||
def items(self) -> Iterable:
|
||||
""" Return iterator of zipped list [keys, values] """
|
||||
with self.lock:
|
||||
for key, value in OrderedDict.items(self):
|
||||
if self.frame - value[1] < self.max_age:
|
||||
yield key, value[0]
|
||||
|
||||
def keys(self) -> Iterable:
|
||||
""" Return iterator of keys """
|
||||
with self.lock:
|
||||
for key, value in OrderedDict.items(self):
|
||||
if self.frame - value[1] < self.max_age:
|
||||
yield key
|
||||
|
||||
def values(self) -> Iterable:
|
||||
""" Return iterator of values """
|
||||
with self.lock:
|
||||
for value in OrderedDict.values(self):
|
||||
if self.frame - value[1] < self.max_age:
|
||||
yield value[0]
|
||||
Reference in New Issue
Block a user