mirror of
https://github.com/MarioSpore/Grinch-AP.git
synced 2025-10-21 12:11:33 -06:00
Core: some typing and docs in various parts of the interface (#1060)
* some typing and docs in various parts of the interface * fix whitespace in docstring Co-authored-by: black-sliver <59490463+black-sliver@users.noreply.github.com> * suggested changes from discussion * remove redundant import * adjust type for json messages * for options module detection: module.lower().endswith("options") Co-authored-by: black-sliver <59490463+black-sliver@users.noreply.github.com>
This commit is contained in:
89
Patch.py
89
Patch.py
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import shutil
|
||||
import json
|
||||
import bsdiff4
|
||||
import bsdiff4 # type: ignore
|
||||
import yaml
|
||||
import os
|
||||
import lzma
|
||||
@@ -10,7 +10,7 @@ import threading
|
||||
import concurrent.futures
|
||||
import zipfile
|
||||
import sys
|
||||
from typing import Tuple, Optional, Dict, Any, Union, BinaryIO
|
||||
from typing import ClassVar, List, Tuple, Optional, Dict, Any, Union, BinaryIO
|
||||
|
||||
import ModuleUpdate
|
||||
ModuleUpdate.update()
|
||||
@@ -21,10 +21,10 @@ current_patch_version = 5
|
||||
|
||||
|
||||
class AutoPatchRegister(type):
|
||||
patch_types: Dict[str, APDeltaPatch] = {}
|
||||
file_endings: Dict[str, APDeltaPatch] = {}
|
||||
patch_types: ClassVar[Dict[str, AutoPatchRegister]] = {}
|
||||
file_endings: ClassVar[Dict[str, AutoPatchRegister]] = {}
|
||||
|
||||
def __new__(cls, name: str, bases, dct: Dict[str, Any]):
|
||||
def __new__(cls, name: str, bases: Tuple[type, ...], dct: Dict[str, Any]) -> AutoPatchRegister:
|
||||
# construct class
|
||||
new_class = super().__new__(cls, name, bases, dct)
|
||||
if "game" in dct:
|
||||
@@ -35,10 +35,11 @@ class AutoPatchRegister(type):
|
||||
return new_class
|
||||
|
||||
@staticmethod
|
||||
def get_handler(file: str) -> Optional[type(APDeltaPatch)]:
|
||||
def get_handler(file: str) -> Optional[AutoPatchRegister]:
|
||||
for file_ending, handler in AutoPatchRegister.file_endings.items():
|
||||
if file.endswith(file_ending):
|
||||
return handler
|
||||
return None
|
||||
|
||||
|
||||
class APContainer:
|
||||
@@ -61,34 +62,36 @@ class APContainer:
|
||||
self.player_name = player_name
|
||||
self.server = server
|
||||
|
||||
def write(self, file: Optional[Union[str, BinaryIO]] = None):
|
||||
if not self.path and not file:
|
||||
def write(self, file: Optional[Union[str, BinaryIO]] = None) -> None:
|
||||
zip_file = file if file else self.path
|
||||
if not zip_file:
|
||||
raise FileNotFoundError(f"Cannot write {self.__class__.__name__} due to no path provided.")
|
||||
with zipfile.ZipFile(file if file else self.path, "w", self.compression_method, True, self.compression_level) \
|
||||
with zipfile.ZipFile(zip_file, "w", self.compression_method, True, self.compression_level) \
|
||||
as zf:
|
||||
if file:
|
||||
self.path = zf.filename
|
||||
self.write_contents(zf)
|
||||
|
||||
def write_contents(self, opened_zipfile: zipfile.ZipFile):
|
||||
def write_contents(self, opened_zipfile: zipfile.ZipFile) -> None:
|
||||
manifest = self.get_manifest()
|
||||
try:
|
||||
manifest = json.dumps(manifest)
|
||||
manifest_str = json.dumps(manifest)
|
||||
except Exception as e:
|
||||
raise Exception(f"Manifest {manifest} did not convert to json.") from e
|
||||
else:
|
||||
opened_zipfile.writestr("archipelago.json", manifest)
|
||||
opened_zipfile.writestr("archipelago.json", manifest_str)
|
||||
|
||||
def read(self, file: Optional[Union[str, BinaryIO]] = None):
|
||||
def read(self, file: Optional[Union[str, BinaryIO]] = None) -> None:
|
||||
"""Read data into patch object. file can be file-like, such as an outer zip file's stream."""
|
||||
if not self.path and not file:
|
||||
zip_file = file if file else self.path
|
||||
if not zip_file:
|
||||
raise FileNotFoundError(f"Cannot read {self.__class__.__name__} due to no path provided.")
|
||||
with zipfile.ZipFile(file if file else self.path, "r") as zf:
|
||||
with zipfile.ZipFile(zip_file, "r") as zf:
|
||||
if file:
|
||||
self.path = zf.filename
|
||||
self.read_contents(zf)
|
||||
|
||||
def read_contents(self, opened_zipfile: zipfile.ZipFile):
|
||||
def read_contents(self, opened_zipfile: zipfile.ZipFile) -> None:
|
||||
with opened_zipfile.open("archipelago.json", "r") as f:
|
||||
manifest = json.load(f)
|
||||
if manifest["compatible_version"] > self.version:
|
||||
@@ -98,7 +101,7 @@ class APContainer:
|
||||
self.server = manifest["server"]
|
||||
self.player_name = manifest["player_name"]
|
||||
|
||||
def get_manifest(self) -> dict:
|
||||
def get_manifest(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"server": self.server, # allow immediate connection to server in multiworld. Empty string otherwise
|
||||
"player": self.player,
|
||||
@@ -114,17 +117,17 @@ class APDeltaPatch(APContainer, metaclass=AutoPatchRegister):
|
||||
"""An APContainer that additionally has delta.bsdiff4
|
||||
containing a delta patch to get the desired file, often a rom."""
|
||||
|
||||
hash = Optional[str] # base checksum of source file
|
||||
hash: Optional[str] # base checksum of source file
|
||||
patch_file_ending: str = ""
|
||||
delta: Optional[bytes] = None
|
||||
result_file_ending: str = ".sfc"
|
||||
source_data: bytes
|
||||
|
||||
def __init__(self, *args, patched_path: str = "", **kwargs):
|
||||
def __init__(self, *args: Any, patched_path: str = "", **kwargs: Any) -> None:
|
||||
self.patched_path = patched_path
|
||||
super(APDeltaPatch, self).__init__(*args, **kwargs)
|
||||
|
||||
def get_manifest(self) -> dict:
|
||||
def get_manifest(self) -> Dict[str, Any]:
|
||||
manifest = super(APDeltaPatch, self).get_manifest()
|
||||
manifest["base_checksum"] = self.hash
|
||||
manifest["result_file_ending"] = self.result_file_ending
|
||||
@@ -205,15 +208,19 @@ def generate_yaml(patch: bytes, metadata: Optional[dict] = None, game: str = GAM
|
||||
return patch.encode(encoding="utf-8-sig")
|
||||
|
||||
|
||||
def generate_patch(rom: bytes, metadata: Optional[dict] = None, game: str = GAME_ALTTP) -> bytes:
|
||||
def generate_patch(rom: bytes, metadata: Optional[Dict[str, Any]] = None, game: str = GAME_ALTTP) -> bytes:
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
patch = bsdiff4.diff(get_base_rom_data(game), rom)
|
||||
return generate_yaml(patch, metadata, game)
|
||||
|
||||
|
||||
def create_patch_file(rom_file_to_patch: str, server: str = "", destination: str = None,
|
||||
player: int = 0, player_name: str = "", game: str = GAME_ALTTP) -> str:
|
||||
def create_patch_file(rom_file_to_patch: str,
|
||||
server: str = "",
|
||||
destination: Optional[str] = None,
|
||||
player: int = 0,
|
||||
player_name: str = "",
|
||||
game: str = GAME_ALTTP) -> str:
|
||||
meta = {"server": server, # allow immediate connection to server in multiworld. Empty string otherwise
|
||||
"player_id": player,
|
||||
"player_name": player_name}
|
||||
@@ -229,19 +236,19 @@ def create_patch_file(rom_file_to_patch: str, server: str = "", destination: str
|
||||
return target
|
||||
|
||||
|
||||
def create_rom_bytes(patch_file: str, ignore_version: bool = False) -> Tuple[dict, str, bytearray]:
|
||||
def create_rom_bytes(patch_file: str, ignore_version: bool = False) -> Tuple[Dict[str, Any], str, bytearray]:
|
||||
data = Utils.parse_yaml(lzma.decompress(load_bytes(patch_file)).decode("utf-8-sig"))
|
||||
game_name = data["game"]
|
||||
if not ignore_version and data["compatible_version"] > current_patch_version:
|
||||
raise RuntimeError("Patch file is incompatible with this patcher, likely an update is required.")
|
||||
patched_data = bsdiff4.patch(get_base_rom_data(game_name), data["patch"])
|
||||
patched_data: bytearray = bsdiff4.patch(get_base_rom_data(game_name), data["patch"])
|
||||
rom_hash = patched_data[int(0x7FC0):int(0x7FD5)]
|
||||
data["meta"]["hash"] = "".join(chr(x) for x in rom_hash)
|
||||
target = os.path.splitext(patch_file)[0] + ".sfc"
|
||||
return data["meta"], target, patched_data
|
||||
|
||||
|
||||
def get_base_rom_data(game: str):
|
||||
def get_base_rom_data(game: str) -> bytes:
|
||||
if game == GAME_ALTTP:
|
||||
from worlds.alttp.Rom import get_base_rom_bytes
|
||||
elif game == "alttp": # old version for A Link to the Past
|
||||
@@ -260,7 +267,7 @@ def get_base_rom_data(game: str):
|
||||
return get_base_rom_bytes()
|
||||
|
||||
|
||||
def create_rom_file(patch_file: str) -> Tuple[dict, str]:
|
||||
def create_rom_file(patch_file: str) -> Tuple[Dict[str, Any], str]:
|
||||
auto_handler = AutoPatchRegister.get_handler(patch_file)
|
||||
if auto_handler:
|
||||
handler: APDeltaPatch = auto_handler(patch_file)
|
||||
@@ -293,7 +300,7 @@ def write_lzma(data: bytes, path: str):
|
||||
f.write(data)
|
||||
|
||||
|
||||
def read_rom(stream, strip_header=True) -> bytearray:
|
||||
def read_rom(stream: BinaryIO, strip_header: bool = True) -> bytearray:
|
||||
"""Reads rom into bytearray and optionally strips off any smc header"""
|
||||
buffer = bytearray(stream.read())
|
||||
if strip_header and len(buffer) % 0x400 == 0x200:
|
||||
@@ -321,7 +328,7 @@ if __name__ == "__main__":
|
||||
elif rom.endswith(".apbp"):
|
||||
print(f"Applying patch {rom}")
|
||||
data, target = create_rom_file(rom)
|
||||
#romfile, adjusted = Utils.get_adjuster_settings(target)
|
||||
# romfile, adjusted = Utils.get_adjuster_settings(target)
|
||||
adjuster_settings = Utils.get_adjuster_settings(GAME_ALTTP)
|
||||
adjusted = False
|
||||
if adjuster_settings:
|
||||
@@ -385,21 +392,9 @@ if __name__ == "__main__":
|
||||
if 'server' in data:
|
||||
Utils.persistent_store("servers", data['hash'], data['server'])
|
||||
print(f"Host is {data['server']}")
|
||||
elif rom.endswith(".apm3"):
|
||||
print(f"Applying patch {rom}")
|
||||
data, target = create_rom_file(rom)
|
||||
print(f"Created rom {target}.")
|
||||
if 'server' in data:
|
||||
Utils.persistent_store("servers", data['hash'], data['server'])
|
||||
print(f"Host is {data['server']}")
|
||||
elif rom.endswith(".apsmz"):
|
||||
print(f"Applying patch {rom}")
|
||||
data, target = create_rom_file(rom)
|
||||
print(f"Created rom {target}.")
|
||||
if 'server' in data:
|
||||
Utils.persistent_store("servers", data['hash'], data['server'])
|
||||
print(f"Host is {data['server']}")
|
||||
elif rom.endswith(".apdkc3"):
|
||||
elif rom.endswith(".apm3") \
|
||||
or rom.endswith(".apsmz") \
|
||||
or rom.endswith(".apdkc3"):
|
||||
print(f"Applying patch {rom}")
|
||||
data, target = create_rom_file(rom)
|
||||
print(f"Created rom {target}.")
|
||||
@@ -410,8 +405,7 @@ if __name__ == "__main__":
|
||||
elif rom.endswith(".zip"):
|
||||
print(f"Updating host in patch files contained in {rom}")
|
||||
|
||||
|
||||
def _handle_zip_file_entry(zfinfo: zipfile.ZipInfo, server: str):
|
||||
def _handle_zip_file_entry(zfinfo: zipfile.ZipInfo, server: str) -> str:
|
||||
data = zfr.read(zfinfo)
|
||||
if zfinfo.filename.endswith(".apbp") or \
|
||||
zfinfo.filename.endswith(".apm3") or \
|
||||
@@ -421,8 +415,7 @@ if __name__ == "__main__":
|
||||
zfw.writestr(zfinfo, data)
|
||||
return zfinfo.filename
|
||||
|
||||
|
||||
futures = []
|
||||
futures: List[concurrent.futures.Future[str]] = []
|
||||
with zipfile.ZipFile(rom, "r") as zfr:
|
||||
updated_zip = os.path.splitext(rom)[0] + "_updated.zip"
|
||||
with zipfile.ZipFile(updated_zip, "w", compression=zipfile.ZIP_DEFLATED,
|
||||
|
Reference in New Issue
Block a user