diff --git a/MultiServer.py b/MultiServer.py index 1c375ee1..564f7949 100644 --- a/MultiServer.py +++ b/MultiServer.py @@ -22,6 +22,9 @@ import ModuleUpdate ModuleUpdate.update() +if typing.TYPE_CHECKING: + import ssl + import websockets import colorama try: @@ -2090,6 +2093,8 @@ def parse_args() -> argparse.Namespace: parser.add_argument('--password', default=defaults["password"]) parser.add_argument('--savefile', default=defaults["savefile"]) parser.add_argument('--disable_save', default=defaults["disable_save"], action='store_true') + parser.add_argument('--cert', help="Path to a SSL Certificate for encryption.") + parser.add_argument('--cert_key', help="Path to SSL Certificate Key file") parser.add_argument('--loglevel', default=defaults["loglevel"], choices=['debug', 'info', 'warning', 'error', 'critical']) parser.add_argument('--location_check_points', default=defaults["location_check_points"], type=int) @@ -2162,6 +2167,14 @@ async def auto_shutdown(ctx, to_cancel=None): await asyncio.sleep(seconds) +def load_server_cert(path: str, cert_key: typing.Optional[str]) -> "ssl.SSLContext": + import ssl + ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + ssl_context.load_default_certs() + ssl_context.load_cert_chain(path, cert_key if cert_key else path) + return ssl_context + + async def main(args: argparse.Namespace): Utils.init_logging("Server", loglevel=args.loglevel.lower()) @@ -2197,8 +2210,10 @@ async def main(args: argparse.Namespace): ctx.init_save(not args.disable_save) + ssl_context = load_server_cert(args.cert, args.cert_key) if args.cert else None + ctx.server = websockets.serve(functools.partial(server, ctx=ctx), host=ctx.host, port=ctx.port, ping_timeout=None, - ping_interval=None) + ping_interval=None, ssl=ssl_context) ip = args.host if args.host else Utils.get_public_ipv4() logging.info('Hosting game at %s:%d (%s)' % (ip, ctx.port, 'No password' if not ctx.password else 'Password: %s' % ctx.password)) diff --git a/WebHostLib/__init__.py b/WebHostLib/__init__.py index 35813f17..e8e5b59d 100644 --- a/WebHostLib/__init__.py +++ b/WebHostLib/__init__.py @@ -24,6 +24,8 @@ app.jinja_env.filters['all'] = all app.config["SELFHOST"] = True # application process is in charge of running the websites app.config["GENERATORS"] = 8 # maximum concurrent world gens app.config["SELFLAUNCH"] = True # application process is in charge of launching Rooms. +app.config["SELFLAUNCHCERT"] = None # can point to a SSL Certificate to encrypt Room websocket connections +app.config["SELFLAUNCHKEY"] = None # can point to a SSL Certificate Key to encrypt Room websocket connections app.config["SELFGEN"] = True # application process is in charge of scheduling Generations. app.config["DEBUG"] = False app.config["PORT"] = 80 diff --git a/WebHostLib/autolauncher.py b/WebHostLib/autolauncher.py index 8de73ba1..4cf72433 100644 --- a/WebHostLib/autolauncher.py +++ b/WebHostLib/autolauncher.py @@ -177,6 +177,8 @@ class MultiworldInstance(): with guardian_lock: multiworlds[self.room_id] = self self.ponyconfig = config["PONY"] + self.cert = config["SELFLAUNCHCERT"] + self.key = config["SELFLAUNCHKEY"] def start(self): if self.process and self.process.is_alive(): @@ -184,7 +186,8 @@ class MultiworldInstance(): logging.info(f"Spinning up {self.room_id}") process = multiprocessing.Process(group=None, target=run_server_process, - args=(self.room_id, self.ponyconfig, get_static_server_data()), + args=(self.room_id, self.ponyconfig, get_static_server_data(), + self.cert, self.key), name="MultiHost") process.start() # bind after start to prevent thread sync issues with guardian. diff --git a/WebHostLib/customserver.py b/WebHostLib/customserver.py index c445b413..9c21fca4 100644 --- a/WebHostLib/customserver.py +++ b/WebHostLib/customserver.py @@ -10,14 +10,16 @@ import random import socket import threading import time +import typing import websockets from pony.orm import commit, db_session, select import Utils -from MultiServer import ClientMessageProcessor, Context, ServerCommandProcessor, auto_shutdown, server -from Utils import cache_argsless, get_public_ipv4, get_public_ipv6, restricted_loads -from .models import Command, Room, db + +from MultiServer import Context, server, auto_shutdown, ServerCommandProcessor, ClientMessageProcessor, load_server_cert +from Utils import get_public_ipv4, get_public_ipv6, restricted_loads, cache_argsless +from .models import Room, Command, db class CustomClientMessageProcessor(ClientMessageProcessor): @@ -137,7 +139,8 @@ def get_static_server_data() -> dict: return data -def run_server_process(room_id, ponyconfig: dict, static_server_data: dict): +def run_server_process(room_id, ponyconfig: dict, static_server_data: dict, + cert_file: typing.Optional[str], cert_key_file: typing.Optional[str]): # establish DB connection for multidata and multisave db.bind(**ponyconfig) db.generate_mapping(check_tables=False) @@ -147,15 +150,15 @@ def run_server_process(room_id, ponyconfig: dict, static_server_data: dict): ctx = WebHostContext(static_server_data) ctx.load(room_id) ctx.init_save() - + ssl_context = load_server_cert(cert_file, cert_key_file) if cert_file else None try: ctx.server = websockets.serve(functools.partial(server, ctx=ctx), ctx.host, ctx.port, ping_timeout=None, - ping_interval=None) + ping_interval=None, ssl=ssl_context) await ctx.server except Exception: # likely port in use - in windows this is OSError, but I didn't check the others ctx.server = websockets.serve(functools.partial(server, ctx=ctx), ctx.host, 0, ping_timeout=None, - ping_interval=None) + ping_interval=None, ssl=ssl_context) await ctx.server port = 0