mirror of
https://github.com/MarioSpore/Grinch-AP.git
synced 2025-10-21 20:21:32 -06:00

GPL -> BSD2Clause and should be faster though I haven't tested it myself and just trusted people on the internet. Jellyfish also allows us access to many more algorithms should they be any better. Trying out Jaro distance now instead of Levenshtein.
503 lines
16 KiB
Python
503 lines
16 KiB
Python
from __future__ import annotations
|
|
|
|
import shutil
|
|
import typing
|
|
import builtins
|
|
import os
|
|
import subprocess
|
|
import sys
|
|
import pickle
|
|
import functools
|
|
import io
|
|
import collections
|
|
import importlib
|
|
import logging
|
|
from tkinter import Tk
|
|
|
|
|
|
def tuplize_version(version: str) -> Version:
    """Parse a dotted version string such as "0.3.2" into a Version.

    Each dot-separated piece is read as a base-10 integer and the pieces are
    handed to Version positionally.
    """
    pieces = [int(piece, 10) for piece in version.split(".")]
    return Version(*pieces)
|
|
|
|
|
|
class Version(typing.NamedTuple):
    """Three-field version number stored as a named tuple of integers."""
    major: int
    minor: int
    build: int
|
|
|
|
|
|
# Version string of this module; version_tuple is the same value parsed by tuplize_version.
__version__ = "0.3.2"
version_tuple = tuplize_version(__version__)
|
|
|
|
import jellyfish
|
|
from yaml import load, load_all, dump, SafeLoader
|
|
|
|
try:
|
|
from yaml import CLoader as Loader
|
|
except ImportError:
|
|
from yaml import Loader
|
|
|
|
|
|
def int16_as_bytes(value: int) -> typing.List[int]:
    """Return the low 16 bits of value as two byte values, least significant byte first."""
    return list((value & 0xFFFF).to_bytes(2, "little"))
|
|
|
|
|
|
def int32_as_bytes(value: int) -> typing.List[int]:
    """Return the low 32 bits of value as four byte values, least significant byte first."""
    return list((value & 0xFFFFFFFF).to_bytes(4, "little"))
|
|
|
|
|
|
def pc_to_snes(value: int) -> int:
    """Convert a file (PC) offset into a SNES address.

    The offset is shifted left by one bit to form the bank byte, its low 15
    bits are kept as the in-bank offset, and bit 15 is always set.
    NOTE(review): presumably the LoROM mapping - confirm against callers.
    """
    bank_bits = (value << 1) & 0x7F0000
    offset_bits = value & 0x7FFF
    return bank_bits | offset_bits | 0x8000
|
|
|
|
|
|
def snes_to_pc(value: int) -> int:
    """Convert a SNES address into a file (PC) offset.

    The bank bits (0x7F0000) are shifted right by one and combined with the
    low 15 bits of the address; bit 15 is dropped.
    """
    bank_part = (value & 0x7F0000) >> 1
    offset_part = value & 0x7FFF
    return bank_part | offset_part
|
|
|
|
|
|
# Type variable for the return type of functions wrapped by cache_argsless.
RetType = typing.TypeVar("RetType")
|
|
|
|
|
|
def cache_argsless(function: typing.Callable[[], RetType]) -> typing.Callable[[], RetType]:
    """Decorator that memoizes a function taking no arguments.

    The wrapped function runs at most once; its return value (including
    None or any other falsy value) is stored and handed back on every later
    call. The wrapper carries over the wrapped function's name, docstring
    and attributes via functools.wraps.

    :param function: callable whose code object declares zero positional arguments.
    :return: a zero-argument callable returning the cached result.
    :raises AssertionError: if function declares positional arguments.
    """
    assert not function.__code__.co_argcount, "Can only cache 0 argument functions with this cache."

    # An empty list means "not computed yet"; a one-element list holds the
    # result, so results such as None are cached correctly as well.
    cached: typing.List[RetType] = []

    @functools.wraps(function)
    def _wrap() -> RetType:
        if not cached:
            cached.append(function())
        return cached[0]

    return _wrap
|
|
|
|
|
|
def is_frozen() -> bool:
    """Return the value of the sys.frozen attribute, or False when it is not set."""
    frozen_flag = getattr(sys, 'frozen', False)
    return typing.cast(bool, frozen_flag)
|
|
|
|
|
|
def local_path(*path: str) -> str:
    """Returns path to a file in the local Archipelago installation or source."""
    if not hasattr(local_path, 'cached_path'):
        # The base directory is resolved once and stored on the function object.
        if is_frozen():
            if hasattr(sys, "_MEIPASS"):
                # we are running in a PyInstaller bundle
                base_dir = sys._MEIPASS  # pylint: disable=protected-access,no-member
            else:
                # cx_Freeze
                base_dir = os.path.dirname(os.path.abspath(sys.argv[0]))
        else:
            import __main__
            if hasattr(__main__, "__file__"):
                # we are running in a normal Python environment
                base_dir = os.path.dirname(os.path.abspath(__main__.__file__))
            else:
                # pray
                base_dir = os.path.abspath(".")
        local_path.cached_path = base_dir

    return os.path.join(local_path.cached_path, *path)
|
|
|
|
|
|
def home_path(*path: str) -> str:
    """Returns path to a file in the user home's Archipelago directory."""
    if not hasattr(home_path, 'cached_path'):
        if sys.platform.startswith('linux'):
            home_path.cached_path = os.path.expanduser('~/Archipelago')
            # Created on first use with owner-only permissions.
            os.makedirs(home_path.cached_path, 0o700, exist_ok=True)
        else:
            # not implemented
            home_path.cached_path = local_path()  # this will generate the same exceptions we got previously

    return os.path.join(home_path.cached_path, *path)
|
|
|
|
|
|
def user_path(*path: str) -> str:
    """Returns either local_path or home_path based on write permissions."""
    if not hasattr(user_path, 'cached_path'):
        if os.access(local_path(), os.W_OK):
            user_path.cached_path = local_path()
        else:
            user_path.cached_path = home_path()
            # populate home from local - TODO: upgrade feature
            needs_seed = user_path.cached_path != local_path() and not os.path.exists(user_path('host.yaml'))
            if needs_seed:
                for dir_name in ('Players', 'data/sprites'):
                    shutil.copytree(local_path(dir_name), user_path(dir_name), dirs_exist_ok=True)
                for file_name in ('manifest.json', 'host.yaml'):
                    shutil.copy2(local_path(file_name), user_path(file_name))

    return os.path.join(user_path.cached_path, *path)
|
|
|
|
|
|
def output_path(*path: str):
    """Return a path below the output directory configured in general_options.output_path.

    The output directory is resolved through user_path() on the first call
    and cached on the function object. Only that first call creates the
    parent directory of the returned path; later calls just join the pieces.
    """
    if not hasattr(output_path, 'cached_path'):
        output_path.cached_path = user_path(get_options()["general_options"]["output_path"])
        target = os.path.join(output_path.cached_path, *path)
        os.makedirs(os.path.dirname(target), exist_ok=True)
        return target
    return os.path.join(output_path.cached_path, *path)
|
|
|
|
|
|
def open_file(filename):
    """Open filename with the platform's default handler.

    Uses os.startfile on win32, the "open" command on darwin and "xdg-open"
    on every other platform.
    """
    if sys.platform == 'win32':
        os.startfile(filename)
        return
    launcher = 'xdg-open' if sys.platform != 'darwin' else 'open'
    subprocess.call([launcher, filename])
|
|
|
|
|
|
# from https://gist.github.com/pypt/94d747fe5180851196eb#gistcomment-4015118 with some changes
|
|
class UniqueKeyLoader(SafeLoader):
    """SafeLoader variant that rejects mappings containing the same key twice."""

    def construct_mapping(self, node, deep=False):
        """Check the keys of node for duplicates, then defer to SafeLoader.

        :raises KeyError: if a key occurs more than once in the mapping.
        """
        seen_keys = set()
        for key_node, _value_node in node.value:
            key = self.construct_object(key_node, deep=deep)
            if key not in seen_keys:
                seen_keys.add(key)
                continue
            logging.error(f"YAML duplicates sanity check failed{key_node.start_mark}")
            raise KeyError(f"Duplicate key {key} found in YAML. Already found keys: {seen_keys}.")
        return super().construct_mapping(node, deep)
|
|
|
|
|
|
# YAML entry points. parse_yaml / parse_yamls use UniqueKeyLoader (a SafeLoader
# subclass that rejects duplicate keys) for single- and multi-document input.
parse_yaml = functools.partial(load, Loader=UniqueKeyLoader)
parse_yamls = functools.partial(load_all, Loader=UniqueKeyLoader)
# SECURITY: this uses the full (non-safe) Loader, which can construct arbitrary
# Python objects - only feed it data this application wrote itself.
unsafe_parse_yaml = functools.partial(load, Loader=Loader)
|
|
|
|
|
|
def get_cert_none_ssl_context():
    """Return a default SSL context with hostname checking and certificate verification turned off."""
    import ssl
    context = ssl.create_default_context()
    # check_hostname has to be switched off before verify_mode may become CERT_NONE.
    context.check_hostname = False
    context.verify_mode = ssl.CERT_NONE
    return context
|
|
|
|
|
|
@cache_argsless
def get_public_ipv4() -> str:
    """Return this machine's public IPv4 address as a string.

    Starts from the address the local hostname resolves to, then asks
    checkip.amazonaws.com and, if that fails, v4.ident.me. When both lookups
    fail the first error is logged and the locally resolved address is
    returned. The result is cached by cache_argsless.
    """
    import socket
    import urllib.request
    ip = socket.gethostbyname(socket.gethostname())
    ctx = get_cert_none_ssl_context()
    try:
        # "with" closes the HTTP response instead of leaving it to the garbage collector.
        with urllib.request.urlopen('https://checkip.amazonaws.com/', context=ctx) as response:
            ip = response.read().decode('utf8').strip()
    except Exception as e:
        try:
            with urllib.request.urlopen('https://v4.ident.me', context=ctx) as response:
                ip = response.read().decode('utf8').strip()
        except Exception:
            # Exception rather than a bare except, so KeyboardInterrupt and
            # SystemExit are no longer swallowed here.
            logging.exception(e)
            # we could be offline, in a local game, so no point in erroring out
    return ip
|
|
|
|
|
|
@cache_argsless
def get_public_ipv6() -> str:
    """Return this machine's public IPv6 address as a string.

    Asks v6.ident.me; when that fails the error is logged and the address
    the local hostname resolves to is returned instead. The result is cached
    by cache_argsless.
    """
    import socket
    import urllib.request
    # NOTE(review): socket.gethostbyname resolves IPv4 only, so this fallback
    # is an IPv4 address - confirm callers cope with that.
    ip = socket.gethostbyname(socket.gethostname())
    ctx = get_cert_none_ssl_context()
    try:
        # "with" closes the HTTP response instead of leaving it to the garbage collector.
        with urllib.request.urlopen('https://v6.ident.me', context=ctx) as response:
            ip = response.read().decode('utf8').strip()
    except Exception as e:
        logging.exception(e)
        # we could be offline, in a local game, or ipv6 may not be available
    return ip
|
|
|
|
|
|
@cache_argsless
def get_default_options() -> dict:
    """Return the built-in default option tree, one sub-dictionary per section.

    The result is cached by cache_argsless, so every caller receives the
    same dictionary object.
    """
    # Refer to host.yaml for comments as to what all these options mean.
    general_defaults = {
        "output_path": "output",
    }
    factorio_defaults = {
        "executable": os.path.join("factorio", "bin", "x64", "factorio"),
    }
    sm_defaults = {
        "rom_file": "Super Metroid (JU).sfc",
        "sni": "SNI",
        "rom_start": True,
    }
    soe_defaults = {
        "rom_file": "Secret of Evermore (USA).sfc",
    }
    lttp_defaults = {
        "rom_file": "Zelda no Densetsu - Kamigami no Triforce (Japan).sfc",
        "sni": "SNI",
        "rom_start": True,
    }
    server_defaults = {
        "host": None,
        "port": 38281,
        "password": None,
        "multidata": None,
        "savefile": None,
        "disable_save": False,
        "loglevel": "info",
        "server_password": None,
        "disable_item_cheat": False,
        "location_check_points": 1,
        "hint_cost": 10,
        "forfeit_mode": "goal",
        "collect_mode": "disabled",
        "remaining_mode": "goal",
        "auto_shutdown": 0,
        "compatibility": 2,
        "log_network": 0,
    }
    generator_defaults = {
        "teams": 1,
        "enemizer_path": os.path.join("EnemizerCLI", "EnemizerCLI.Core.exe"),
        "player_files_path": "Players",
        "players": 0,
        "weights_file_path": "weights.yaml",
        "meta_file_path": "meta.yaml",
        "spoiler": 2,
        "glitch_triforce_room": 1,
        "race": 0,
        "plando_options": "bosses",
    }
    minecraft_defaults = {
        "forge_directory": "Minecraft Forge server",
        "max_heap_size": "2G",
    }
    oot_defaults = {
        "rom_file": "The Legend of Zelda - Ocarina of Time.z64",
    }

    return {
        "general_options": general_defaults,
        "factorio_options": factorio_defaults,
        "sm_options": sm_defaults,
        "soe_options": soe_defaults,
        "lttp_options": lttp_defaults,
        "server_options": server_defaults,
        "generator": generator_defaults,
        "minecraft_options": minecraft_defaults,
        "oot_options": oot_defaults,
    }
|
|
|
|
|
|
def update_options(src: dict, dest: dict, filename: str, keys: list) -> dict:
    """Recursively fill dest with every option from src that it lacks.

    Values already present in dest win, except that a non-dictionary value
    in dest is overwritten when src holds a dictionary for the same key.
    Messages are logged only when filename ends with "options.yaml".

    :param src: option tree providing the fallback values.
    :param dest: option tree to complete; modified in place.
    :param filename: name of the file dest was read from, used in log messages.
    :param keys: path of keys leading to src/dest, used to name options in log messages.
    :return: dest.
    """
    for key, value in src.items():
        new_keys = [*keys, key]
        option_name = '.'.join(new_keys)

        if key not in dest:
            dest[key] = value
            if filename.endswith("options.yaml"):
                logging.info(f"Warning: {filename} is missing {option_name}")
            continue

        if not isinstance(value, dict):
            # plain value that dest already defines: keep dest's value
            continue

        current = dest[key]
        if isinstance(current, dict):
            dest[key] = update_options(value, current, filename, new_keys)
            continue

        if filename.endswith("options.yaml"):
            logging.info(f"Warning: {filename} has {option_name}, but it is not a dictionary. overwriting.")
        dest[key] = value
    return dest
|
|
|
|
|
|
@cache_argsless
def get_options() -> dict:
    """Load the option tree from the first options file found and return it.

    Candidates, in order: "options.yaml" and "host.yaml" in the current
    working directory (only when that differs from local_path()), then the
    same two names under user_path(). The file's content is completed with
    the values from get_default_options(). The result is stored on the
    function object and reused on later calls.

    :raises FileNotFoundError: if none of the candidate files exists.
    """
    if not hasattr(get_options, "options"):
        filenames = ("options.yaml", "host.yaml")
        locations = []
        if os.getcwd() != local_path():
            locations += filenames  # use files from cwd only if it's not the local_path
        locations += [user_path(filename) for filename in filenames]

        for location in locations:
            if os.path.exists(location):
                with open(location) as f:
                    options = parse_yaml(f.read())
                if options is None:
                    # An empty YAML document parses to None; treat it as
                    # "no overrides" so the defaults are used instead of
                    # failing inside update_options.
                    options = {}

                get_options.options = update_options(get_default_options(), options, location, list())
                break
        else:
            raise FileNotFoundError(f"Could not find {filenames[1]} to load options.")
    return get_options.options
|
|
|
|
|
|
def get_item_name_from_id(code: int) -> str:
    """Look up the item name for code, or return an 'Unknown item' placeholder containing the code."""
    from worlds import lookup_any_item_id_to_name
    placeholder = f'Unknown item (ID:{code})'
    return lookup_any_item_id_to_name.get(code, placeholder)
|
|
|
|
|
|
def get_location_name_from_id(code: int) -> str:
    """Look up the location name for code, or return an 'Unknown location' placeholder containing the code."""
    from worlds import lookup_any_location_id_to_name
    placeholder = f'Unknown location (ID:{code})'
    return lookup_any_location_id_to_name.get(code, placeholder)
|
|
|
|
|
|
def persistent_store(category: str, key: typing.Any, value: typing.Any):
    """Store value under storage[category][key] and write the storage to disk.

    The storage dictionary comes from persistent_load(); the change is made
    on that object and the whole storage is then written to
    "_persistent_storage.yaml" under user_path().

    :param category: top-level section name; created when missing.
    :param key: key within the section.
    :param value: value to store.
    """
    path = user_path("_persistent_storage.yaml")
    storage: dict = persistent_load()
    storage.setdefault(category, {})[key] = value
    # Serialize before opening: mode "wt" truncates the file immediately, so
    # a failing dump() must not be allowed to wipe the existing store.
    serialized = dump(storage)
    with open(path, "wt") as f:
        f.write(serialized)
|
|
|
|
|
|
def persistent_load() -> typing.Dict[str, dict]:
    """Return the persistent storage dictionary, reading it from disk when needed.

    A non-empty storage held on the function object is returned directly.
    Otherwise "_persistent_storage.yaml" under user_path() is parsed; a
    missing, unreadable or empty file yields an empty dictionary. The result
    is kept on the function object for later calls.
    """
    storage = getattr(persistent_load, "storage", None)
    if storage:
        return storage
    path = user_path("_persistent_storage.yaml")
    storage = {}
    if os.path.exists(path):
        try:
            with open(path, "r") as f:
                # SECURITY: unsafe_parse_yaml uses the full YAML loader; this
                # file must only ever hold data written by this application.
                storage = unsafe_parse_yaml(f.read())
        except Exception as e:
            logging.debug(f"Could not read store: {e}")
    if not isinstance(storage, dict):
        # Covers an empty file (parsed as None) as well as any other
        # non-mapping content, which callers could not use via .get()/.setdefault().
        storage = {}
    persistent_load.storage = storage
    return storage
|
|
|
|
|
|
def get_adjuster_settings(gameName: str):
    """Return the stored "adjuster" settings for gameName, or an empty dict when there are none."""
    all_adjuster_settings = persistent_load().get("adjuster", {})
    return all_adjuster_settings.get(gameName, {})
|
|
|
|
|
|
@cache_argsless
def get_unique_identifier():
    """Return the client identifier kept in persistent storage, creating it on first use.

    A truthy value stored under client/uuid is returned as is. Otherwise the
    value of uuid.getnode() is stored there and returned.
    """
    stored_id = persistent_load().get("client", {}).get("uuid", None)
    if stored_id:
        return stored_id

    import uuid
    node_id = uuid.getnode()
    persistent_store("client", "uuid", node_id)
    return node_id
|
|
|
|
|
|
# Names from the builtins module that RestrictedUnpickler may resolve.
safe_builtins = {'set', 'frozenset'}
|
|
|
|
|
|
class RestrictedUnpickler(pickle.Unpickler):
    """Unpickler that resolves only an explicit allow-list of globals.

    Allowed: the builtins named in safe_builtins, selected NetUtils classes,
    PlandoItem / PlandoConnection from worlds.generic, and subclasses of
    Options.Option found in modules whose name ends with "Options".
    Everything else raises pickle.UnpicklingError.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Resolved once here so find_class can hand out attributes of these modules.
        self.options_module = importlib.import_module("Options")
        self.net_utils_module = importlib.import_module("NetUtils")
        self.generic_properties_module = importlib.import_module("worlds.generic")

    def find_class(self, module, name):
        """Resolve the global module.name if it is allowed.

        :raises pickle.UnpicklingError: if the requested global is not allowed.
        """
        if module == "builtins" and name in safe_builtins:
            return getattr(builtins, name)
        # used by MultiServer -> savegame/multidata
        if module == "NetUtils" and name in {"NetworkItem", "ClientStatus", "Hint", "SlotType", "NetworkSlot"}:
            return getattr(self.net_utils_module, name)
        # Options and Plando are unpickled by WebHost -> Generate
        if module == "worlds.generic" and name in {"PlandoItem", "PlandoConnection"}:
            return getattr(self.generic_properties_module, name)
        if module.endswith("Options"):
            if module == "Options":
                mod = self.options_module
            else:
                # NOTE(review): this imports whatever module the pickle names,
                # as long as the name ends in "Options" - consider an allow-list.
                mod = importlib.import_module(module)
            obj = getattr(mod, name)
            # issubclass() raises TypeError for non-classes; check the type
            # first so functions and constants are rejected as forbidden.
            if isinstance(obj, type) and issubclass(obj, self.options_module.Option):
                return obj
        # Forbid everything else.
        raise pickle.UnpicklingError("global '%s.%s' is forbidden" % (module, name))
|
|
|
|
|
|
def restricted_loads(s):
    """Helper function analogous to pickle.loads()."""
    buffer = io.BytesIO(s)
    unpickler = RestrictedUnpickler(buffer)
    return unpickler.load()
|
|
|
|
|
|
class KeyedDefaultDict(collections.defaultdict):
    """defaultdict whose default_factory is called with the missing key as its argument."""

    def __missing__(self, key):
        # Build the value from the key, remember it, and hand it back.
        created = self.default_factory(key)
        self[key] = created
        return created
|
|
|
|
|
|
def get_text_between(text: str, start: str, end: str) -> str:
    """Return the part of text after the first start and before the last end.

    :raises ValueError: if start or end does not occur in text.
    """
    begin = text.index(start) + len(start)
    stop = text.rindex(end)
    return text[begin:stop]
|
|
|
|
|
|
# Maps log level names to the matching logging module constants.
loglevel_mapping = {
    'error': logging.ERROR,
    'info': logging.INFO,
    'warning': logging.WARNING,
    'debug': logging.DEBUG,
}
|
|
|
|
|
|
def init_logging(name: str, loglevel: typing.Union[str, int] = logging.INFO, write_mode: str = "w",
                 log_format: str = "[%(name)s at %(asctime)s]: %(message)s", exception_logger: str = ""):
    """Reset the root logger to write to logs/<name>.txt and, if available, stdout.

    All handlers currently attached to the root logger are removed and
    closed first. Unless sys.excepthook has already been wrapped by an
    earlier call, it is replaced by a hook that logs uncaught exceptions.

    :param name: base name of the log file inside user_path("logs").
    :param loglevel: a key of loglevel_mapping or a value used as the level directly.
    :param write_mode: file mode passed to the file handler.
    :param log_format: format string applied to the file handler only.
    :param exception_logger: name of the logger that reports uncaught exceptions.
    """
    resolved_level = loglevel_mapping.get(loglevel, loglevel)
    log_folder = user_path("logs")
    os.makedirs(log_folder, exist_ok=True)

    root_logger = logging.getLogger()
    # Iterate over a copy because handlers are removed while looping.
    for old_handler in list(root_logger.handlers):
        root_logger.removeHandler(old_handler)
        old_handler.close()
    root_logger.setLevel(resolved_level)

    log_file = os.path.join(log_folder, f"{name}.txt")
    file_handler = logging.FileHandler(log_file, write_mode, encoding="utf-8-sig")
    file_handler.setFormatter(logging.Formatter(log_format))
    root_logger.addHandler(file_handler)

    if sys.stdout:
        root_logger.addHandler(logging.StreamHandler(sys.stdout))

    # Relay unhandled exceptions to logger.
    if getattr(sys.excepthook, "_wrapped", False):  # skip if already modified
        return

    orig_hook = sys.excepthook

    def handle_exception(exc_type, exc_value, exc_traceback):
        if issubclass(exc_type, KeyboardInterrupt):
            sys.__excepthook__(exc_type, exc_value, exc_traceback)
            return
        target_logger = logging.getLogger(exception_logger)
        target_logger.exception("Uncaught exception", exc_info=(exc_type, exc_value, exc_traceback))
        return orig_hook(exc_type, exc_value, exc_traceback)

    handle_exception._wrapped = True

    sys.excepthook = handle_exception
|
|
|
|
|
|
def stream_input(stream, queue):
    """Start a daemon thread that forwards non-empty lines from stream into queue.

    Each line is stripped of surrounding whitespace; lines that are empty
    after stripping are skipped. The thread ends when the stream reaches
    end of file.

    :param stream: object with readline() and a name attribute.
    :param queue: object with put_nowait(); receives the stripped lines.
    :return: the started thread.
    """
    def queuer():
        while 1:
            line = stream.readline()
            if not line:
                # readline() returns an empty value only at end of file;
                # without this check the loop would spin forever on a closed stream.
                break
            text = line.strip()
            if text:
                queue.put_nowait(text)

    from threading import Thread
    thread = Thread(target=queuer, name=f"Stream handler for {stream.name}", daemon=True)
    thread.start()
    return thread
|
|
|
|
|
|
def tkinter_center_window(window: Tk):
    """Move window so that its requested size is centered on the screen."""
    # update() runs first so the requested width/height are current.
    window.update()
    x_offset = int(window.winfo_screenwidth() / 2 - window.winfo_reqwidth() / 2)
    y_offset = int(window.winfo_screenheight() / 2 - window.winfo_reqheight() / 2)
    position = "+{}+{}".format(x_offset, y_offset)
    window.geometry(position)
|
|
|
|
|
|
class VersionException(Exception):
    """Exception type for version-related errors; adds nothing to Exception."""
|
|
|
|
|
|
# noinspection PyPep8Naming
|
|
def format_SI_prefix(value, power=1000, power_labels=('', 'k', 'M', 'G', 'T', "P", "E", "Z", "Y")) -> str:
    """Format value as a number followed by a space and a magnitude prefix.

    value is divided by power until it is smaller than power or the last
    label is reached. An int that needed no scaling is printed as is;
    anything else is printed with three decimals.

    :param value: number to format.
    :param power: ratio between two neighbouring prefixes.
    :param power_labels: prefixes in ascending order, starting with the unscaled one.
    :return: for example "1.500 k" for 1500, or "5 " for 5.
    """
    n = 0
    last_label = len(power_labels) - 1
    # ">=" so that exactly one power (e.g. 1000) becomes "1.000 k"; the label
    # bound keeps very large values on the last prefix instead of raising IndexError.
    while value >= power and n < last_label:
        value /= power
        n += 1
    # Exact type check on purpose: only plain ints skip the float formatting.
    if type(value) is int:
        return f"{value} {power_labels[n]}"
    return f"{value:0.3f} {power_labels[n]}"
|
|
|
|
|
|
def get_fuzzy_results(input_word: str, wordlist: typing.Sequence[str]) -> typing.List[typing.Tuple[str, int]]:
    """Score every word in wordlist against input_word, best match first.

    The score is jellyfish's Jaro-Winkler similarity multiplied by 100 and
    truncated to an int. Words with equal scores keep their wordlist order.

    :return: list of (candidate, score) tuples sorted by descending score.
    """
    scored = [
        (candidate, int(100*jellyfish.jaro_winkler_similarity(input_word, candidate)))
        for candidate in wordlist
    ]
    scored.sort(key=lambda element: element[1], reverse=True)
    return scored
|