159 lines
		
	
	
		
			5.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
		
		
			
		
	
	
			159 lines
		
	
	
		
			5.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
|   | from __future__ import annotations | ||
|  | 
 | ||
|  | from collections import OrderedDict | ||
|  | from threading import RLock | ||
|  | from typing import TYPE_CHECKING, Any, Iterable, Union | ||
|  | 
 | ||
|  | if TYPE_CHECKING: | ||
|  |     from .bot_ai import BotAI | ||
|  | 
 | ||
|  | 
 | ||
class ExpiringDict(OrderedDict):
    """
    An expiring dict that uses ``bot.state.game_loop`` to only return items that are valid
    for a specific amount of frames.

    Every value is stored internally as a ``(value, frame)`` tuple, where ``frame`` is the
    game loop at insertion time. An entry is considered valid while
    ``current_frame - frame < max_age``. Expired entries are not removed eagerly; they are
    dropped when they are looked up via ``in``, ``d[key]`` or ``pop``, and skipped by
    iteration, ``len`` and ``repr``.

    Example usage::

        async def on_step(self, iteration: int):
            if iteration == 0:
                # Only return values that have been added less than 20 frames ago
                self.my_dict = ExpiringDict(self, max_age_frames=20)
                # Add item
                self.my_dict["test"] = "something"
            if iteration == 2:
                # On default, one iteration is called every 8 frames
                if "test" in self.my_dict:
                    print("test is in dict")
            if iteration == 20:
                if "test" not in self.my_dict:
                    print("test is not anymore in dict")
    """

    def __init__(self, bot: "BotAI", max_age_frames: int = 1):
        """
        :param bot: Object providing the current frame via ``bot.state.game_loop``.
        :param max_age_frames: Entries are valid while their age in frames is strictly
            smaller than this value.
        """
        assert max_age_frames >= -1
        assert bot

        super().__init__()
        self.bot: "BotAI" = bot
        self.max_age: Union[int, float] = max_age_frames
        # Re-entrant, because methods holding the lock call other locking methods
        # (e.g. update -> __setitem__, __len__ -> values).
        self.lock: RLock = RLock()

    @property
    def frame(self) -> int:
        """ The current game loop, read from the bot on every access. """
        return self.bot.state.game_loop

    def _is_fresh(self, item) -> bool:
        """ Return True if the stored (value, frame) tuple has not expired yet. """
        return self.frame - item[1] < self.max_age

    def __contains__(self, key) -> bool:
        """ Return True if dict has a non-expired key, else False, e.g. 'key in dict' """
        with self.lock:
            if not OrderedDict.__contains__(self, key):
                return False
            # Each item is a tuple of (value, frame time)
            if self._is_fresh(OrderedDict.__getitem__(self, key)):
                return True
            # Expired: clean up lazily on lookup
            OrderedDict.__delitem__(self, key)
            return False

    def __getitem__(self, key, with_age=False) -> Any:
        """
        Return the item of the dict using d[key].

        If ``with_age`` is True, return a tuple (value, frame of insertion) instead.
        Raises KeyError if the key is missing or expired; an expired entry is removed.
        """
        with self.lock:
            # Each item is a tuple of (value, frame time)
            item = OrderedDict.__getitem__(self, key)
            if self._is_fresh(item):
                return (item[0], item[1]) if with_age else item[0]
            OrderedDict.__delitem__(self, key)
        raise KeyError(key)

    def __setitem__(self, key, value):
        """ Set d[key] = value, remembering the current frame as insertion time """
        with self.lock:
            OrderedDict.__setitem__(self, key, (value, self.frame))

    def __repr__(self):
        """ Printable version of the dict instead of getting memory address """
        with self.lock:
            # Note: the stored (value, frame) tuple is printed, not only the value
            entries = [
                f"{repr(key)}: {repr(item)}" for key, item in OrderedDict.items(self) if self._is_fresh(item)
            ]
        print_str = ", ".join(entries)
        return f"ExpiringDict({print_str})"

    def __str__(self):
        return self.__repr__()

    def __iter__(self):
        """ Override 'for key in dict:', only yields non-expired keys """
        with self.lock:
            return self.keys()

    # TODO find a way to improve len
    def __len__(self):
        """Override len method as key value pairs aren't instantly being deleted, but only on lookup.

        This function is slow because it has to check if each element is not expired yet."""
        with self.lock:
            return sum(1 for _ in self.values())

    def pop(self, key, default=None, with_age=False):
        """
        Return the item and remove it.

        An expired entry is removed as well, but treated like a missing key.
        If the key is missing or expired: raises KeyError when ``default`` is None,
        otherwise returns ``default`` (or ``(default, current frame)`` if ``with_age``).
        """
        with self.lock:
            if OrderedDict.__contains__(self, key):
                item = OrderedDict.__getitem__(self, key)
                # The entry is removed regardless of whether it is still valid
                OrderedDict.__delitem__(self, key)
                if self._is_fresh(item):
                    return (item[0], item[1]) if with_age else item[0]
            if default is None:
                raise KeyError(key)
            return (default, self.frame) if with_age else default

    def get(self, key, default=None, with_age=False):
        """
        Return the value for key if key is in dict and not expired, else default.

        Unlike ``dict.get``, this raises KeyError when the key is missing or expired and
        ``default`` is None. With ``with_age`` a tuple is returned: (value, frame of
        insertion) on a hit, (default, current frame) on a miss.
        """
        with self.lock:
            if OrderedDict.__contains__(self, key):
                item = OrderedDict.__getitem__(self, key)
                if self._is_fresh(item):
                    return (item[0], item[1]) if with_age else item[0]
            if default is None:
                raise KeyError(key)
            return (default, self.frame) if with_age else default

    def update(self, other_dict: dict):
        """ Insert all key value pairs of other_dict, stamped with the current frame """
        with self.lock:
            for key, value in other_dict.items():
                self[key] = value

    def items(self) -> Iterable:
        """ Return iterator over (key, value) pairs of non-expired entries """
        with self.lock:
            for key, item in OrderedDict.items(self):
                if self._is_fresh(item):
                    yield key, item[0]

    def keys(self) -> Iterable:
        """ Return iterator of non-expired keys """
        with self.lock:
            for key, item in OrderedDict.items(self):
                if self._is_fresh(item):
                    yield key

    def values(self) -> Iterable:
        """ Return iterator of non-expired values """
        with self.lock:
            for item in OrderedDict.values(self):
                if self._is_fresh(item):
                    yield item[0]