mirror of
https://github.com/MarioSpore/Grinch-AP.git
synced 2025-10-21 20:21:32 -06:00
Use a proper multiset for progression items
This can cut generation times in half in some cases
This commit is contained in:
552
_vendor/collections_extended/setlists.py
Normal file
552
_vendor/collections_extended/setlists.py
Normal file
@@ -0,0 +1,552 @@
|
||||
"""Setlist class definitions."""
|
||||
import random as random_
|
||||
|
||||
from collections import (
|
||||
Sequence,
|
||||
Set,
|
||||
MutableSequence,
|
||||
MutableSet,
|
||||
Hashable,
|
||||
)
|
||||
|
||||
from . import _util
|
||||
|
||||
|
||||
class _basesetlist(Sequence, Set):
|
||||
"""A setlist is an ordered Collection of unique elements.
|
||||
|
||||
_basesetlist is the superclass of setlist and frozensetlist. It is immutable
|
||||
and unhashable.
|
||||
"""
|
||||
|
||||
def __init__(self, iterable=None, raise_on_duplicate=False):
|
||||
"""Create a setlist.
|
||||
|
||||
Args:
|
||||
iterable (Iterable): Values to initialize the setlist with.
|
||||
"""
|
||||
self._list = list()
|
||||
self._dict = dict()
|
||||
if iterable:
|
||||
if raise_on_duplicate:
|
||||
self._extend(iterable)
|
||||
else:
|
||||
self._update(iterable)
|
||||
|
||||
def __repr__(self):
|
||||
if len(self) == 0:
|
||||
return '{0}()'.format(self.__class__.__name__)
|
||||
else:
|
||||
repr_format = '{class_name}({values!r})'
|
||||
return repr_format.format(
|
||||
class_name=self.__class__.__name__,
|
||||
values=tuple(self),
|
||||
)
|
||||
|
||||
# Convenience methods
|
||||
def _fix_neg_index(self, index):
|
||||
if index < 0:
|
||||
index += len(self)
|
||||
if index < 0:
|
||||
raise IndexError('index is out of range')
|
||||
return index
|
||||
|
||||
def _fix_end_index(self, index):
|
||||
if index is None:
|
||||
return len(self)
|
||||
else:
|
||||
return self._fix_neg_index(index)
|
||||
|
||||
def _append(self, value):
|
||||
# Checking value in self will check that value is Hashable
|
||||
if value in self:
|
||||
raise ValueError('Value "%s" already present' % str(value))
|
||||
else:
|
||||
self._dict[value] = len(self)
|
||||
self._list.append(value)
|
||||
|
||||
def _extend(self, values):
|
||||
new_values = set()
|
||||
for value in values:
|
||||
if value in new_values:
|
||||
raise ValueError('New values contain duplicates')
|
||||
elif value in self:
|
||||
raise ValueError('New values contain elements already present in self')
|
||||
else:
|
||||
new_values.add(value)
|
||||
for value in values:
|
||||
self._dict[value] = len(self)
|
||||
self._list.append(value)
|
||||
|
||||
def _add(self, item):
|
||||
if item not in self:
|
||||
self._dict[item] = len(self)
|
||||
self._list.append(item)
|
||||
|
||||
def _update(self, values):
|
||||
for value in values:
|
||||
if value not in self:
|
||||
self._dict[value] = len(self)
|
||||
self._list.append(value)
|
||||
|
||||
@classmethod
|
||||
def _from_iterable(cls, it, **kwargs):
|
||||
return cls(it, **kwargs)
|
||||
|
||||
# Implement Container
|
||||
def __contains__(self, value):
|
||||
return value in self._dict
|
||||
|
||||
# Iterable we get by inheriting from Sequence
|
||||
|
||||
# Implement Sized
|
||||
def __len__(self):
|
||||
return len(self._list)
|
||||
|
||||
# Implement Sequence
|
||||
def __getitem__(self, index):
|
||||
if isinstance(index, slice):
|
||||
return self._from_iterable(self._list[index])
|
||||
return self._list[index]
|
||||
|
||||
def count(self, value):
|
||||
"""Return the number of occurences of value in self.
|
||||
|
||||
This runs in O(1)
|
||||
|
||||
Args:
|
||||
value: The value to count
|
||||
Returns:
|
||||
int: 1 if the value is in the setlist, otherwise 0
|
||||
"""
|
||||
if value in self:
|
||||
return 1
|
||||
else:
|
||||
return 0
|
||||
|
||||
def index(self, value, start=0, end=None):
|
||||
"""Return the index of value between start and end.
|
||||
|
||||
By default, the entire setlist is searched.
|
||||
|
||||
This runs in O(1)
|
||||
|
||||
Args:
|
||||
value: The value to find the index of
|
||||
start (int): The index to start searching at (defaults to 0)
|
||||
end (int): The index to stop searching at (defaults to the end of the list)
|
||||
Returns:
|
||||
int: The index of the value
|
||||
Raises:
|
||||
ValueError: If the value is not in the list or outside of start - end
|
||||
IndexError: If start or end are out of range
|
||||
"""
|
||||
try:
|
||||
index = self._dict[value]
|
||||
except KeyError:
|
||||
raise ValueError
|
||||
else:
|
||||
start = self._fix_neg_index(start)
|
||||
end = self._fix_end_index(end)
|
||||
if start <= index and index < end:
|
||||
return index
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
@classmethod
|
||||
def _check_type(cls, other, operand_name):
|
||||
if not isinstance(other, _basesetlist):
|
||||
message = (
|
||||
"unsupported operand type(s) for {operand_name}: "
|
||||
"'{self_type}' and '{other_type}'").format(
|
||||
operand_name=operand_name,
|
||||
self_type=cls,
|
||||
other_type=type(other),
|
||||
)
|
||||
raise TypeError(message)
|
||||
|
||||
def __add__(self, other):
|
||||
self._check_type(other, '+')
|
||||
out = self.copy()
|
||||
out._extend(other)
|
||||
return out
|
||||
|
||||
# Implement Set
|
||||
|
||||
def issubset(self, other):
|
||||
return self <= other
|
||||
|
||||
def issuperset(self, other):
|
||||
return self >= other
|
||||
|
||||
def union(self, other):
|
||||
out = self.copy()
|
||||
out.update(other)
|
||||
return out
|
||||
|
||||
def intersection(self, other):
|
||||
other = set(other)
|
||||
return self._from_iterable(item for item in self if item in other)
|
||||
|
||||
def difference(self, other):
|
||||
other = set(other)
|
||||
return self._from_iterable(item for item in self if item not in other)
|
||||
|
||||
def symmetric_difference(self, other):
|
||||
return self.union(other) - self.intersection(other)
|
||||
|
||||
def __sub__(self, other):
|
||||
self._check_type(other, '-')
|
||||
return self.difference(other)
|
||||
|
||||
def __and__(self, other):
|
||||
self._check_type(other, '&')
|
||||
return self.intersection(other)
|
||||
|
||||
def __or__(self, other):
|
||||
self._check_type(other, '|')
|
||||
return self.union(other)
|
||||
|
||||
def __xor__(self, other):
|
||||
self._check_type(other, '^')
|
||||
return self.symmetric_difference(other)
|
||||
|
||||
# Comparison
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, _basesetlist):
|
||||
return False
|
||||
if not len(self) == len(other):
|
||||
return False
|
||||
for self_elem, other_elem in zip(self, other):
|
||||
if self_elem != other_elem:
|
||||
return False
|
||||
return True
|
||||
|
||||
def __ne__(self, other):
|
||||
return not (self == other)
|
||||
|
||||
# New methods
|
||||
|
||||
def sub_index(self, sub, start=0, end=None):
|
||||
"""Return the index of a subsequence.
|
||||
|
||||
This runs in O(len(sub))
|
||||
|
||||
Args:
|
||||
sub (Sequence): An Iterable to search for
|
||||
Returns:
|
||||
int: The index of the first element of sub
|
||||
Raises:
|
||||
ValueError: If sub isn't a subsequence
|
||||
TypeError: If sub isn't iterable
|
||||
IndexError: If start or end are out of range
|
||||
"""
|
||||
start_index = self.index(sub[0], start, end)
|
||||
end = self._fix_end_index(end)
|
||||
if start_index + len(sub) > end:
|
||||
raise ValueError
|
||||
for i in range(1, len(sub)):
|
||||
if sub[i] != self[start_index + i]:
|
||||
raise ValueError
|
||||
return start_index
|
||||
|
||||
def copy(self):
|
||||
return self.__class__(self)
|
||||
|
||||
|
||||
class setlist(_basesetlist, MutableSequence, MutableSet):
|
||||
"""A mutable (unhashable) setlist."""
|
||||
|
||||
def __str__(self):
|
||||
return '{[%s}]' % ', '.join(repr(v) for v in self)
|
||||
|
||||
# Helper methods
|
||||
def _delete_all(self, elems_to_delete, raise_errors):
|
||||
indices_to_delete = set()
|
||||
for elem in elems_to_delete:
|
||||
try:
|
||||
elem_index = self._dict[elem]
|
||||
except KeyError:
|
||||
if raise_errors:
|
||||
raise ValueError('Passed values contain elements not in self')
|
||||
else:
|
||||
if elem_index in indices_to_delete:
|
||||
if raise_errors:
|
||||
raise ValueError('Passed vales contain duplicates')
|
||||
indices_to_delete.add(elem_index)
|
||||
self._delete_values_by_index(indices_to_delete)
|
||||
|
||||
def _delete_values_by_index(self, indices_to_delete):
|
||||
deleted_count = 0
|
||||
for i, elem in enumerate(self._list):
|
||||
if i in indices_to_delete:
|
||||
deleted_count += 1
|
||||
del self._dict[elem]
|
||||
else:
|
||||
new_index = i - deleted_count
|
||||
self._list[new_index] = elem
|
||||
self._dict[elem] = new_index
|
||||
# Now remove deleted_count items from the end of the list
|
||||
if deleted_count:
|
||||
self._list = self._list[:-deleted_count]
|
||||
|
||||
# Set/Sequence agnostic
|
||||
def pop(self, index=-1):
|
||||
"""Remove and return the item at index."""
|
||||
value = self._list.pop(index)
|
||||
del self._dict[value]
|
||||
return value
|
||||
|
||||
def clear(self):
|
||||
"""Remove all elements from self."""
|
||||
self._dict = dict()
|
||||
self._list = list()
|
||||
|
||||
# Implement MutableSequence
|
||||
def __setitem__(self, index, value):
|
||||
if isinstance(index, slice):
|
||||
old_values = self[index]
|
||||
for v in value:
|
||||
if v in self and v not in old_values:
|
||||
raise ValueError
|
||||
self._list[index] = value
|
||||
self._dict = {}
|
||||
for i, v in enumerate(self._list):
|
||||
self._dict[v] = i
|
||||
else:
|
||||
index = self._fix_neg_index(index)
|
||||
old_value = self._list[index]
|
||||
if value in self:
|
||||
if value == old_value:
|
||||
return
|
||||
else:
|
||||
raise ValueError
|
||||
del self._dict[old_value]
|
||||
self._list[index] = value
|
||||
self._dict[value] = index
|
||||
|
||||
def __delitem__(self, index):
|
||||
if isinstance(index, slice):
|
||||
indices_to_delete = set(self.index(e) for e in self._list[index])
|
||||
self._delete_values_by_index(indices_to_delete)
|
||||
else:
|
||||
index = self._fix_neg_index(index)
|
||||
value = self._list[index]
|
||||
del self._dict[value]
|
||||
for elem in self._list[index + 1:]:
|
||||
self._dict[elem] -= 1
|
||||
del self._list[index]
|
||||
|
||||
def insert(self, index, value):
|
||||
"""Insert value at index.
|
||||
|
||||
Args:
|
||||
index (int): Index to insert value at
|
||||
value: Value to insert
|
||||
Raises:
|
||||
ValueError: If value already in self
|
||||
IndexError: If start or end are out of range
|
||||
"""
|
||||
if value in self:
|
||||
raise ValueError
|
||||
index = self._fix_neg_index(index)
|
||||
self._dict[value] = index
|
||||
for elem in self._list[index:]:
|
||||
self._dict[elem] += 1
|
||||
self._list.insert(index, value)
|
||||
|
||||
def append(self, value):
|
||||
"""Append value to the end.
|
||||
|
||||
Args:
|
||||
value: Value to append
|
||||
Raises:
|
||||
ValueError: If value alread in self
|
||||
TypeError: If value isn't hashable
|
||||
"""
|
||||
self._append(value)
|
||||
|
||||
def extend(self, values):
|
||||
"""Append all values to the end.
|
||||
|
||||
If any of the values are present, ValueError will
|
||||
be raised and none of the values will be appended.
|
||||
|
||||
Args:
|
||||
values (Iterable): Values to append
|
||||
Raises:
|
||||
ValueError: If any values are already present or there are duplicates
|
||||
in the passed values.
|
||||
TypeError: If any of the values aren't hashable.
|
||||
"""
|
||||
self._extend(values)
|
||||
|
||||
def __iadd__(self, values):
|
||||
"""Add all values to the end of self.
|
||||
|
||||
Args:
|
||||
values (Iterable): Values to append
|
||||
Raises:
|
||||
ValueError: If any values are already present
|
||||
"""
|
||||
self._check_type(values, '+=')
|
||||
self.extend(values)
|
||||
return self
|
||||
|
||||
def remove(self, value):
|
||||
"""Remove value from self.
|
||||
|
||||
Args:
|
||||
value: Element to remove from self
|
||||
Raises:
|
||||
ValueError: if element is already present
|
||||
"""
|
||||
try:
|
||||
index = self._dict[value]
|
||||
except KeyError:
|
||||
raise ValueError('Value "%s" is not present.')
|
||||
else:
|
||||
del self[index]
|
||||
|
||||
def remove_all(self, elems_to_delete):
|
||||
"""Remove all elements from elems_to_delete, raises ValueErrors.
|
||||
|
||||
See Also:
|
||||
discard_all
|
||||
Args:
|
||||
elems_to_delete (Iterable): Elements to remove.
|
||||
Raises:
|
||||
ValueError: If the count of any element is greater in
|
||||
elems_to_delete than self.
|
||||
TypeError: If any of the values aren't hashable.
|
||||
"""
|
||||
self._delete_all(elems_to_delete, raise_errors=True)
|
||||
|
||||
# Implement MutableSet
|
||||
|
||||
def add(self, item):
|
||||
"""Add an item.
|
||||
|
||||
Note:
|
||||
This does not raise a ValueError for an already present value like
|
||||
append does. This is to match the behavior of set.add
|
||||
Args:
|
||||
item: Item to add
|
||||
Raises:
|
||||
TypeError: If item isn't hashable.
|
||||
"""
|
||||
self._add(item)
|
||||
|
||||
def update(self, values):
|
||||
"""Add all values to the end.
|
||||
|
||||
If any of the values are present, silently ignore
|
||||
them (as opposed to extend which raises an Error).
|
||||
|
||||
See also:
|
||||
extend
|
||||
Args:
|
||||
values (Iterable): Values to add
|
||||
Raises:
|
||||
TypeError: If any of the values are unhashable.
|
||||
"""
|
||||
self._update(values)
|
||||
|
||||
def discard_all(self, elems_to_delete):
|
||||
"""Discard all the elements from elems_to_delete.
|
||||
|
||||
This is much faster than removing them one by one.
|
||||
This runs in O(len(self) + len(elems_to_delete))
|
||||
|
||||
Args:
|
||||
elems_to_delete (Iterable): Elements to discard.
|
||||
Raises:
|
||||
TypeError: If any of the values aren't hashable.
|
||||
"""
|
||||
self._delete_all(elems_to_delete, raise_errors=False)
|
||||
|
||||
def discard(self, value):
|
||||
"""Discard an item.
|
||||
|
||||
Note:
|
||||
This does not raise a ValueError for a missing value like remove does.
|
||||
This is to match the behavior of set.discard
|
||||
"""
|
||||
try:
|
||||
self.remove(value)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
def difference_update(self, other):
|
||||
"""Update self to include only the differene with other."""
|
||||
other = set(other)
|
||||
indices_to_delete = set()
|
||||
for i, elem in enumerate(self):
|
||||
if elem in other:
|
||||
indices_to_delete.add(i)
|
||||
if indices_to_delete:
|
||||
self._delete_values_by_index(indices_to_delete)
|
||||
|
||||
def intersection_update(self, other):
|
||||
"""Update self to include only the intersection with other."""
|
||||
other = set(other)
|
||||
indices_to_delete = set()
|
||||
for i, elem in enumerate(self):
|
||||
if elem not in other:
|
||||
indices_to_delete.add(i)
|
||||
if indices_to_delete:
|
||||
self._delete_values_by_index(indices_to_delete)
|
||||
|
||||
def symmetric_difference_update(self, other):
|
||||
"""Update self to include only the symmetric difference with other."""
|
||||
other = setlist(other)
|
||||
indices_to_delete = set()
|
||||
for i, item in enumerate(self):
|
||||
if item in other:
|
||||
indices_to_delete.add(i)
|
||||
for item in other:
|
||||
self.add(item)
|
||||
self._delete_values_by_index(indices_to_delete)
|
||||
|
||||
def __isub__(self, other):
|
||||
self._check_type(other, '-=')
|
||||
self.difference_update(other)
|
||||
return self
|
||||
|
||||
def __iand__(self, other):
|
||||
self._check_type(other, '&=')
|
||||
self.intersection_update(other)
|
||||
return self
|
||||
|
||||
def __ior__(self, other):
|
||||
self._check_type(other, '|=')
|
||||
self.update(other)
|
||||
return self
|
||||
|
||||
def __ixor__(self, other):
|
||||
self._check_type(other, '^=')
|
||||
self.symmetric_difference_update(other)
|
||||
return self
|
||||
|
||||
# New methods
|
||||
def shuffle(self, random=None):
|
||||
"""Shuffle all of the elements in self randomly."""
|
||||
random_.shuffle(self._list, random=random)
|
||||
for i, elem in enumerate(self._list):
|
||||
self._dict[elem] = i
|
||||
|
||||
def sort(self, *args, **kwargs):
|
||||
"""Sort this setlist in place."""
|
||||
self._list.sort(*args, **kwargs)
|
||||
for index, value in enumerate(self._list):
|
||||
self._dict[value] = index
|
||||
|
||||
|
||||
class frozensetlist(_basesetlist, Hashable):
|
||||
"""An immutable (hashable) setlist."""
|
||||
|
||||
def __hash__(self):
|
||||
if not hasattr(self, '_hash_value'):
|
||||
self._hash_value = _util.hash_iterable(self)
|
||||
return self._hash_value
|
||||
Reference in New Issue
Block a user