From 7b0b243607f1b6396995c9ce0a8e3ea4f9cebd66 Mon Sep 17 00:00:00 2001 From: Fabian Dill Date: Sun, 28 Nov 2021 04:06:30 +0100 Subject: [PATCH] MultiServer: remove promp_toolkit --- CommonClient.py | 16 ++-------------- FactorioClient.py | 1 + MultiServer.py | 29 ++++++++++++++++++----------- SNIClient.py | 5 ----- Utils.py | 34 +++++++++++++++++++++++----------- WebHostLib/customserver.py | 2 +- requirements.txt | 1 - 7 files changed, 45 insertions(+), 43 deletions(-) diff --git a/CommonClient.py b/CommonClient.py index 53f7d675..cd21aa29 100644 --- a/CommonClient.py +++ b/CommonClient.py @@ -15,7 +15,7 @@ if __name__ == "__main__": from MultiServer import CommandProcessor from NetUtils import Endpoint, decode, NetworkItem, encode, JSONtoTextParser, ClientStatus, Permission -from Utils import Version +from Utils import Version, stream_input from worlds import network_data_package, AutoWorldRegister logger = logging.getLogger("Client") @@ -540,18 +540,6 @@ async def process_server_cmd(ctx: CommonContext, args: dict): ctx.on_package(cmd, args) -def stream_input(stream, queue): - def queuer(): - text = stream.readline().strip() - if text: - queue.put_nowait(text) - - from threading import Thread - thread = Thread(target=queuer, name=f"Stream handler for {stream.name}", daemon=True) - thread.start() - return thread - - async def console_loop(ctx: CommonContext): import sys commandprocessor = ctx.command_processor(ctx) @@ -560,7 +548,7 @@ async def console_loop(ctx: CommonContext): while not ctx.exit_event.is_set(): try: input_text = await queue.get() - input_text = input_text.strip() + queue.task_done() if ctx.input_requests > 0: ctx.input_requests -= 1 diff --git a/FactorioClient.py b/FactorioClient.py index 5c4d9814..4e88ac9a 100644 --- a/FactorioClient.py +++ b/FactorioClient.py @@ -191,6 +191,7 @@ async def factorio_server_watcher(ctx: FactorioContext): while not factorio_queue.empty(): msg = factorio_queue.get() + factorio_queue.task_done() factorio_server_logger.info(msg) if not ctx.rcon_client and "Starting RCON interface at IP ADDR:" in msg: ctx.rcon_client = factorio_rcon.RCONClient("localhost", rcon_port, rcon_password) diff --git a/MultiServer.py b/MultiServer.py index 726a21f1..5e953e6a 100644 --- a/MultiServer.py +++ b/MultiServer.py @@ -119,7 +119,7 @@ class Context: self.remaining_mode: str = remaining_mode self.collect_mode: str = collect_mode self.item_cheat = item_cheat - self.running = True + self.exit_event = asyncio.Event() self.client_activity_timers: typing.Dict[ team_slot, datetime.datetime] = {} # datetime of last new item check self.client_connection_timers: typing.Dict[ @@ -336,7 +336,7 @@ class Context: if not self.auto_saver_thread: def save_regularly(): import time - while self.running: + while not self.exit_event.is_set(): time.sleep(self.auto_save_interval) if self.save_dirty: logging.debug("Saving via thread.") @@ -1409,7 +1409,7 @@ class ServerCommandProcessor(CommonCommandProcessor): asyncio.create_task(self.ctx.server.ws_server._close()) if self.ctx.shutdown_task: self.ctx.shutdown_task.cancel() - self.ctx.running = False + self.ctx.exit_event.set() return True @mark_raw @@ -1566,11 +1566,17 @@ class ServerCommandProcessor(CommonCommandProcessor): async def console(ctx: Context): - session = prompt_toolkit.PromptSession() - while ctx.running: - with patch_stdout(): - input_text = await session.prompt_async() + import sys + queue = asyncio.Queue() + Utils.stream_input(sys.stdin, queue) + while not ctx.exit_event.is_set(): try: + # I don't get why this while loop is needed. Works fine without it on clients, + # but the queue.get() for server never fulfills if the queue is empty when entering the await. + while queue.qsize() == 0: + await asyncio.sleep(0.05) + input_text = await queue.get() + queue.task_done() ctx.commandprocessor(input_text) except: import traceback @@ -1636,10 +1642,10 @@ def parse_args() -> argparse.Namespace: async def auto_shutdown(ctx, to_cancel=None): await asyncio.sleep(ctx.auto_shutdown) - while ctx.running: + while not ctx.exit_event.is_set(): if not ctx.client_activity_timers.values(): asyncio.create_task(ctx.server.ws_server._close()) - ctx.running = False + ctx.exit_event.set() if to_cancel: for task in to_cancel: task.cancel() @@ -1650,7 +1656,7 @@ async def auto_shutdown(ctx, to_cancel=None): seconds = ctx.auto_shutdown - delta.total_seconds() if seconds < 0: asyncio.create_task(ctx.server.ws_server._close()) - ctx.running = False + ctx.exit_event.set() if to_cancel: for task in to_cancel: task.cancel() @@ -1694,7 +1700,8 @@ async def main(args: argparse.Namespace): console_task = asyncio.create_task(console(ctx)) if ctx.auto_shutdown: ctx.shutdown_task = asyncio.create_task(auto_shutdown(ctx, [console_task])) - await console_task + await ctx.exit_event.wait() + console_task.cancel() if ctx.shutdown_task: await ctx.shutdown_task diff --git a/SNIClient.py b/SNIClient.py index 10225f73..61262afe 100644 --- a/SNIClient.py +++ b/SNIClient.py @@ -680,11 +680,6 @@ async def snes_disconnect(ctx: Context): async def snes_autoreconnect(ctx: Context): - # unfortunately currently broken. See: https://github.com/prompt-toolkit/python-prompt-toolkit/issues/1033 - # with prompt_toolkit.shortcuts.ProgressBar() as pb: - # for _ in pb(range(100)): - # await asyncio.sleep(RECONNECT_DELAY/100) - await asyncio.sleep(SNES_RECONNECT_DELAY) if ctx.snes_reconnect_address and ctx.snes_socket is None: await snes_connect(ctx, ctx.snes_reconnect_address) diff --git a/Utils.py b/Utils.py index 84e4378f..29491c62 100644 --- a/Utils.py +++ b/Utils.py @@ -1,6 +1,16 @@ from __future__ import annotations import typing +import builtins +import os +import subprocess +import sys +import pickle +import functools +import io +import collections +import importlib +import logging def tuplize_version(version: str) -> Version: @@ -16,17 +26,6 @@ class Version(typing.NamedTuple): __version__ = "0.2.0" version_tuple = tuplize_version(__version__) -import builtins -import os -import subprocess -import sys -import pickle -import functools -import io -import collections -import importlib -import logging - from yaml import load, dump, safe_load try: @@ -462,3 +461,16 @@ def init_logging(name: str, loglevel: typing.Union[str, int] = logging.INFO, wri handle_exception._wrapped = True sys.excepthook = handle_exception + + +def stream_input(stream, queue): + def queuer(): + while 1: + text = stream.readline().strip() + if text: + queue.put_nowait(text) + + from threading import Thread + thread = Thread(target=queuer, name=f"Stream handler for {stream.name}", daemon=True) + thread.start() + return thread diff --git a/WebHostLib/customserver.py b/WebHostLib/customserver.py index 20d1d784..4ddf01ed 100644 --- a/WebHostLib/customserver.py +++ b/WebHostLib/customserver.py @@ -56,7 +56,7 @@ class WebHostContext(Context): def listen_to_db_commands(self): cmdprocessor = DBCommandProcessor(self) - while self.running: + while not self.exit_event.is_set(): with db_session: commands = select(command for command in Command if command.room.id == self.room_id) if commands: diff --git a/requirements.txt b/requirements.txt index c32a5185..8a35b905 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,6 @@ colorama>=0.4.4 websockets>=10.1 PyYAML>=6.0 fuzzywuzzy>=0.18.0 -prompt_toolkit>=3.0.23 appdirs>=1.4.4 jinja2>=3.0.3 schema>=0.7.4