mirror of
				https://github.com/MarioSpore/Grinch-AP.git
				synced 2025-10-21 20:21:32 -06:00 
			
		
		
		
	
		
			
				
	
	
		
			356 lines
		
	
	
		
			14 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			356 lines
		
	
	
		
			14 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| from __future__ import annotations
 | |
| 
 | |
| import asyncio
 | |
| import collections
 | |
| import datetime
 | |
| import functools
 | |
| import logging
 | |
| import multiprocessing
 | |
| import pickle
 | |
| import random
 | |
| import socket
 | |
| import threading
 | |
| import time
 | |
| import typing
 | |
| import sys
 | |
| 
 | |
| import websockets
 | |
| from pony.orm import commit, db_session, select
 | |
| 
 | |
| import Utils
 | |
| 
 | |
| from MultiServer import Context, server, auto_shutdown, ServerCommandProcessor, ClientMessageProcessor, load_server_cert
 | |
| from Utils import restricted_loads, cache_argsless
 | |
| from .locker import Locker
 | |
| from .models import Command, GameDataPackage, Room, db
 | |
| 
 | |
| 
 | |
class CustomClientMessageProcessor(ClientMessageProcessor):
    # Client command processor used by the web host; adds the "video" command on top of the defaults.
    ctx: WebHostContext

    def _cmd_video(self, platform: str, user: str):
        """Set a link for your name in the WebHostLib tracker pointing to a video stream.
        Currently, only YouTube and Twitch platforms are supported.
        """
        # NOTE: the docstring above is shown to players as command help, so its text must stay as-is.
        # Only the first letter of the platform is checked, case-insensitively ("t..." / "y...").
        platform_key = platform.lower()
        if platform_key.startswith("t"):  # twitch
            service = "Twitch"
            message = f"Registered Twitch Stream https://www.twitch.tv/{user}"
        elif platform_key.startswith("y"):  # youtube
            service = "Youtube"
            message = f"Registered Youtube Stream for {user}"
        else:
            # unknown platform: report failure to the command processor
            return False
        # keyed by the calling client's (team, slot); persisted immediately via ctx.save()
        self.ctx.video[self.client.team, self.client.slot] = service, user
        self.ctx.save()
        self.output(message)
        return True
 | |
| 
 | |
| 
 | |
# inject: make MultiServer use the web-host client command processor, then remove the module
# name again so it is not left behind in this module's namespace.
import MultiServer as _multiserver_module

_multiserver_module.client_message_processor = CustomClientMessageProcessor
del _multiserver_module
 | |
| 
 | |
| 
 | |
class DBCommandProcessor(ServerCommandProcessor):
    # Server command processor for commands queued in the database (see
    # WebHostContext.listen_to_db_commands); its output goes to the room's logger.

    def output(self, text: str):
        room_logger = self.ctx.logger
        room_logger.info(text)
 | |
| 
 | |
| 
 | |
class WebHostContext(Context):
    # Multiworld server context for a single web-hosted room. Game data is taken from the pre-computed
    # static_server_data dict instead of importing the worlds package, and the room's persistent state
    # (save data, port, activity timestamp, queued server commands) lives in the pony ORM tables.
    room_id: int

    def __init__(self, static_server_data: dict, logger: logging.Logger):
        """Create the context for one room; must be called while an asyncio event loop is running.

        :param static_server_data: mapping of context attribute name -> value, applied by ``_load_game_data``.
        :param logger: logger receiving this room's messages.
        """
        # static server data is used during _load_game_data to load required data,
        # without needing to import worlds system, which takes quite a bit of memory
        self.static_server_data = static_server_data
        super(WebHostContext, self).__init__("", 0, "", "", 1,
                                             40, True, "enabled", "enabled",
                                             "enabled", 0, 2, logger=logger)
        # presumably only needed while the base __init__ runs _load_game_data; drop the reference afterwards
        del self.static_server_data
        # the DB command listener thread schedules commands onto this loop (see listen_to_db_commands);
        # get_running_loop() raises RuntimeError if no loop is running
        self.main_loop = asyncio.get_running_loop()
        self.video = {}
        self.tags = ["AP", "WebHost"]

    def __del__(self):
        """Log that the context is being destroyed, with the process RSS when psutil is installed."""
        try:
            import psutil
            from Utils import format_SI_prefix
            self.logger.debug(f"Context destroyed, Mem: {format_SI_prefix(psutil.Process().memory_info().rss, 1024)}iB")
        except ImportError:
            self.logger.debug("Context destroyed")

    def _load_game_data(self):
        """Set every entry of ``static_server_data`` as an attribute of this context."""
        for key, value in self.static_server_data.items():
            # NOTE: attributes are mutable and shared, so they will have to be copied before being modified
            setattr(self, key, value)
        # games without an entry get an empty frozenset of non-hintable names
        self.non_hintable_names = collections.defaultdict(frozenset, self.non_hintable_names)

    def listen_to_db_commands(self):
        """Poll the Command table every 5 seconds until ``exit_event`` is set.

        Each command queued for this room is handed to a DBCommandProcessor on the main event loop
        (this method runs in a background thread started by ``init_save``) and then deleted.
        """
        cmdprocessor = DBCommandProcessor(self)

        while not self.exit_event.is_set():
            with db_session:
                # materialize once, so the emptiness check and the iteration use the same result set
                commands = select(command for command in Command if command.room.id == self.room_id)[:]
                if commands:
                    for command in commands:
                        self.main_loop.call_soon_threadsafe(cmdprocessor, command.commandtext)
                        command.delete()
                    commit()
            time.sleep(5)

    @db_session
    def load(self, room_id: int):
        """Load the multidata of room ``room_id`` into this context.

        Chooses the port (the room's stored last_port if set, otherwise a random one) and prepares
        per-game data: games whose checksum matches the static data use the shared static dicts,
        others use their stored GameDataPackage row when one exists.

        :returns: whatever ``Context._load`` returns.
        """
        self.room_id = room_id
        room = Room.get(id=room_id)
        # NOTE(review): a failed start stores last_port = -1, which is truthy and is reused here as-is.
        if room.last_port:
            self.port = room.last_port
        else:
            self.port = get_random_port()

        multidata = self.decompress(room.seed.multidata)
        game_data_packages = {}

        static_gamespackage = self.gamespackage  # this is shared across all rooms
        static_item_name_groups = self.item_name_groups
        static_location_name_groups = self.location_name_groups
        self.gamespackage = {"Archipelago": static_gamespackage.get("Archipelago", {})}  # this may be modified by _load
        self.item_name_groups = {"Archipelago": static_item_name_groups.get("Archipelago", {})}
        self.location_name_groups = {"Archipelago": static_location_name_groups.get("Archipelago", {})}

        for game in list(multidata.get("datapackage", {})):
            game_data = multidata["datapackage"][game]
            if "checksum" in game_data:
                if static_gamespackage.get(game, {}).get("checksum") == game_data["checksum"]:
                    # non-custom. remove from multidata and use static data
                    # games package could be dropped from static data once all rooms embed data package
                    del multidata["datapackage"][game]
                else:
                    row = GameDataPackage.get(checksum=game_data["checksum"])
                    if row:  # None if rolled on >= 0.3.9 but uploaded to <= 0.3.8. multidata should be complete
                        game_data_packages[game] = Utils.restricted_loads(row.data)
                        continue
                    else:
                        self.logger.warning(f"Did not find game_data_package for {game}: {game_data['checksum']}")
            self.gamespackage[game] = static_gamespackage.get(game, {})
            self.item_name_groups[game] = static_item_name_groups.get(game, {})
            self.location_name_groups[game] = static_location_name_groups.get(game, {})

        if not game_data_packages:
            # all static -> use the static dicts directly
            self.gamespackage = static_gamespackage
            self.item_name_groups = static_item_name_groups
            self.location_name_groups = static_location_name_groups
        return self._load(multidata, game_data_packages, True)

    @db_session
    def init_save(self, enabled: bool = True):
        """Enable/disable saving and start the DB command listener thread.

        When enabled, restores the room's stored save (if any) and starts periodic async saving.
        The command listener thread is started regardless of ``enabled``.
        """
        self.saving = enabled
        if self.saving:
            savegame_data = Room.get(id=self.room_id).multisave
            if savegame_data:
                # reuse the blob already read instead of fetching the Room a second time
                self.set_save(restricted_loads(savegame_data))
            self._start_async_saving(atexit_save=False)
        threading.Thread(target=self.listen_to_db_commands, daemon=True).start()

    @db_session
    def _save(self, exit_save: bool = False) -> bool:
        """Store the current save state on the Room row; always returns True."""
        room = Room.get(id=self.room_id)
        room.multisave = pickle.dumps(self.get_save())
        # saving only occurs on activity, so we can "abuse" this information to mark this as last_activity
        if not exit_save:  # we don't want to count a shutdown as activity, which would restart the server again
            room.last_activity = datetime.datetime.utcnow()
        return True

    def get_save(self) -> dict:
        """Return the base save data plus the registered video links under the "video" key."""
        d = super(WebHostContext, self).get_save()
        d["video"] = [(tuple(playerslot), videodata) for playerslot, videodata in self.video.items()]
        return d
 | |
| 
 | |
| 
 | |
def get_random_port():
    # Pick a random port from the dynamic/private range, both bounds inclusive.
    lowest_port, highest_port = 49152, 65535
    # randint(a, b) is defined as randrange(a, b + 1), so this draws exactly the same values
    return random.randrange(lowest_port, highest_port + 1)
 | |
| 
 | |
| 
 | |
@cache_argsless
def get_static_server_data() -> dict:
    # Collect the per-game data room servers need (hint blacklists, data packages without the
    # name-group entries, item/location name groups), keyed by world/game name. The result is
    # cached via cache_argsless, and the worlds package is only imported here.
    import worlds
    world_types = worlds.AutoWorldRegister.world_types
    group_keys = ("item_name_groups", "location_name_groups")

    return {
        "non_hintable_names": {
            game_name: world_type.hint_blacklist
            for game_name, world_type in world_types.items()
        },
        "gamespackage": {
            game_name: {
                entry_key: entry_value
                for entry_key, entry_value in package.items()
                if entry_key not in group_keys
            }
            for game_name, package in worlds.network_data_package["games"].items()
        },
        "item_name_groups": {
            game_name: world_type.item_name_groups
            for game_name, world_type in world_types.items()
        },
        "location_name_groups": {
            game_name: world_type.location_name_groups
            for game_name, world_type in world_types.items()
        },
    }
 | |
| 
 | |
| 
 | |
def set_up_logging(room_id) -> logging.Logger:
    # Return the logger for one room, writing at INFO level to "<room_id>.txt" (append mode) in the
    # user "logs" directory; any handlers already attached to that logger are closed and removed first.
    import os
    room_logger = logging.getLogger(f"RoomLogger {room_id}")

    # this *should* be empty, but just in case. Iterate a copy, since removal mutates the list.
    for old_handler in list(room_logger.handlers):
        room_logger.removeHandler(old_handler)
        old_handler.close()

    log_path = os.path.join(Utils.user_path("logs"), f"{room_id}.txt")
    file_handler = logging.FileHandler(log_path, "a", encoding="utf-8-sig")
    file_handler.setFormatter(logging.Formatter("[%(asctime)s]: %(message)s"))
    room_logger.setLevel(logging.INFO)
    room_logger.addHandler(file_handler)
    return room_logger
 | |
| 
 | |
| 
 | |
def run_server_process(name: str, ponyconfig: dict, static_server_data: dict,
                       cert_file: typing.Optional[str], cert_key_file: typing.Optional[str],
                       host: str, rooms_to_run: multiprocessing.Queue, rooms_shutting_down: multiprocessing.Queue):
    """Run a room-hosting worker: bind the DB, then serve rooms on an asyncio loop forever.

    A daemon thread takes room ids from ``rooms_to_run`` and schedules ``start_room`` for each on the
    loop. Whenever a room stops (normally or with an error), its id is put on ``rooms_shutting_down``.

    :param name: name of this worker; used for its log and in log messages.
    :param ponyconfig: keyword arguments for ``db.bind``.
    :param static_server_data: game data handed to every WebHostContext
        (presumably from ``get_static_server_data``; verify against caller).
    :param cert_file: optional TLS certificate file; when falsy, rooms are served without TLS.
    :param cert_key_file: key file belonging to ``cert_file``.
    :param host: host name only used in the "Hosting game at" log line.
    :param rooms_to_run: queue of room ids to start.
    :param rooms_shutting_down: queue receiving room ids once their room has stopped.
    :raises Exception: if the worlds package has already been imported into this process.
    """
    Utils.init_logging(name)
    try:
        import resource
    except ModuleNotFoundError:
        pass  # unix only module
    else:
        # Each Server is another file handle, so request as many as we can from the system
        file_limit = resource.getrlimit(resource.RLIMIT_NOFILE)[1]
        # set soft limit to hard limit
        resource.setrlimit(resource.RLIMIT_NOFILE, (file_limit, file_limit))
        del resource, file_limit

    # establish DB connection for multidata and multisave
    db.bind(**ponyconfig)
    db.generate_mapping(check_tables=False)

    if "worlds" in sys.modules:
        raise Exception("Worlds system should not be loaded in the custom server.")

    import gc
    ssl_context = load_server_cert(cert_file, cert_key_file) if cert_file else None
    del cert_file, cert_key_file, ponyconfig
    gc.collect()  # free intermediate objects used during setup

    # asyncio.get_event_loop() without a current loop is deprecated (and an error on newer Pythons),
    # so create and install the loop explicitly.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)

    async def start_room(room_id):
        with Locker(f"RoomLocker {room_id}"):
            # Pre-bound so the handlers below still work when setup fails before these are assigned;
            # the root logger is the fallback until the room logger exists.
            ctx: typing.Optional[WebHostContext] = None
            logger: logging.Logger = logging.getLogger()
            try:
                logger = set_up_logging(room_id)
                ctx = WebHostContext(static_server_data, logger)
                ctx.load(room_id)
                ctx.init_save()
                assert ctx.server is None
                try:
                    ctx.server = websockets.serve(
                        functools.partial(server, ctx=ctx), ctx.host, ctx.port, ssl=ssl_context)

                    await ctx.server
                except OSError:  # likely port in use
                    # retry on port 0, i.e. let the OS pick a free port
                    ctx.server = websockets.serve(
                        functools.partial(server, ctx=ctx), ctx.host, 0, ssl=ssl_context)

                    await ctx.server
                port = 0
                for wssocket in ctx.server.ws_server.sockets:
                    socketname = wssocket.getsockname()
                    if wssocket.family == socket.AF_INET6:
                        # Prefer IPv4, as most users seem to not have working ipv6 support
                        if not port:
                            port = socketname[1]
                    elif wssocket.family == socket.AF_INET:
                        port = socketname[1]
                if port:
                    ctx.logger.info(f'Hosting game at {host}:{port}')
                    with db_session:
                        room = Room.get(id=ctx.room_id)
                        room.last_port = port
                else:
                    # not inside an except handler, so .exception() would only append "NoneType: None"
                    ctx.logger.error("Could not determine port. Likely hosting failure.")
                with db_session:
                    ctx.auto_shutdown = Room.get(id=room_id).timeout
                if ctx.saving:
                    # picked up by the shutdown loop at the end of run_server_process
                    setattr(asyncio.current_task(), "save", lambda: ctx._save(True))
                assert ctx.shutdown_task is None
                ctx.shutdown_task = asyncio.create_task(auto_shutdown(ctx, []))
                await ctx.shutdown_task

            except (KeyboardInterrupt, SystemExit):
                if ctx is not None and ctx.saving:
                    ctx._save()
                    setattr(asyncio.current_task(), "save", None)
            except Exception as e:
                # mark the room as failed to start/run
                with db_session:
                    room = Room.get(id=room_id)
                    room.last_port = -1
                logger.exception(e)
                raise
            else:
                if ctx.saving:
                    ctx._save()
                    setattr(asyncio.current_task(), "save", None)
            finally:
                try:
                    if ctx is not None:
                        ctx.save_dirty = False  # make sure the saving thread does not write to DB after final wakeup
                        ctx.exit_event.set()  # make sure the saving thread stops at some point
                    # NOTE: async saving should probably be an async task and could be merged with shutdown_task
                    with db_session:
                        # ensure the Room does not spin up again on its own, minute of safety buffer
                        room = Room.get(id=room_id)
                        room.last_activity = datetime.datetime.utcnow() - \
                                             datetime.timedelta(minutes=1, seconds=room.timeout)
                    logging.info(f"Shutting down room {room_id} on {name}.")
                finally:
                    await asyncio.sleep(5)
                    rooms_shutting_down.put(room_id)

    class Starter(threading.Thread):
        # Blocks on rooms_to_run and schedules start_room on the event loop for each room id received.
        _tasks: typing.List[asyncio.Future]

        def __init__(self):
            super().__init__()
            self._tasks = []

        def _done(self, task: asyncio.Future):
            self._tasks.remove(task)
            # re-raise a failed room's exception so it is reported instead of silently dropped
            task.result()

        def run(self):
            while 1:
                next_room = rooms_to_run.get(block=True, timeout=None)
                gc.collect()
                task = asyncio.run_coroutine_threadsafe(start_room(next_room), loop)
                self._tasks.append(task)
                task.add_done_callback(self._done)
                logging.info(f"Starting room {next_room} on {name}.")
                del task  # delete reference to task object

    starter = Starter()
    starter.daemon = True
    starter.start()
    try:
        loop.run_forever()
    finally:
        # save all tasks that want to be saved during shutdown
        for task in asyncio.all_tasks(loop):
            save: typing.Optional[typing.Callable[[], typing.Any]] = getattr(task, "save", None)
            if save:
                save()
 | 
