diff --git a/MultiServer.py b/MultiServer.py index 9f04b9ba..6009a2a5 100644 --- a/MultiServer.py +++ b/MultiServer.py @@ -9,6 +9,7 @@ import zlib import collections import typing import inspect +import weakref import ModuleUpdate @@ -41,6 +42,13 @@ class Client: self.tags = [] self.version = [0, 0, 0] self.messageprocessor = ClientMessageProcessor(ctx, self) + self.ctx = weakref.ref(ctx) + + async def disconnect(self): + ctx = self.ctx() + if ctx: + await on_client_disconnected(ctx, self) + ctx.clients.remove(self) @property def wants_item_notification(self): @@ -96,20 +104,25 @@ class Context: f'for {len(received_items)} players') -async def send_msgs(websocket, msgs): +async def send_msgs(client: Client, msgs): + websocket = client.socket if not websocket or not websocket.open or websocket.closed: return - await websocket.send(json.dumps(msgs)) + try: + await websocket.send(json.dumps(msgs)) + except websockets.ConnectionClosed: + logging.exception("Exception during send_msgs") + await client.disconnect() def broadcast_all(ctx : Context, msgs): for client in ctx.clients: if client.auth: - asyncio.create_task(send_msgs(client.socket, msgs)) + asyncio.create_task(send_msgs(client, msgs)) def broadcast_team(ctx : Context, team, msgs): for client in ctx.clients: if client.auth and client.team == team: - asyncio.create_task(send_msgs(client.socket, msgs)) + asyncio.create_task(send_msgs(client, msgs)) def notify_all(ctx : Context, text): logging.info("Notice (all): %s" % text) @@ -125,7 +138,7 @@ def notify_client(client: Client, text: str): if not client.auth: return logging.info("Notice (Player %s in team %d): %s" % (client.name, client.team + 1, text)) - asyncio.create_task(send_msgs(client.socket, [['Print', text]])) + asyncio.create_task(send_msgs(client, [['Print', text]])) # separated out, due to compatibilty between clients @@ -140,7 +153,7 @@ def notify_hints(ctx: Context, team: int, hints: typing.List[Utils.Hint]): payload = cmd else: payload = texts - asyncio.create_task(send_msgs(client.socket, payload)) + asyncio.create_task(send_msgs(client, payload)) async def server(websocket, path, ctx: Context): client = Client(websocket, ctx) @@ -161,11 +174,10 @@ async def server(websocket, path, ctx: Context): if not isinstance(e, websockets.WebSocketException): logging.exception(e) finally: - await on_client_disconnected(ctx, client) - ctx.clients.remove(client) + await client.disconnect() async def on_client_connected(ctx: Context, client: Client): - await send_msgs(client.socket, [['RoomInfo', { + await send_msgs(client, [['RoomInfo', { 'password': ctx.password is not None, 'players': [(client.team, client.slot, client.name) for client in ctx.clients if client.auth], # tags are for additional features in the communication. @@ -230,7 +242,8 @@ def send_new_items(ctx: Context): continue items = get_received_items(ctx, client.team, client.slot) if len(items) > client.send_index: - asyncio.create_task(send_msgs(client.socket, [['ReceivedItems', (client.send_index, tuplize_received_items(items)[client.send_index:])]])) + asyncio.create_task(send_msgs(client, [ + ['ReceivedItems', (client.send_index, tuplize_received_items(items)[client.send_index:])]])) client.send_index = len(items) @@ -267,7 +280,7 @@ def register_location_checks(ctx: Context, team: int, slot: int, locations): for client in ctx.clients: if client.team == team and client.wants_item_notification: asyncio.create_task( - send_msgs(client.socket, [['ItemFound', (target_item, location, slot)]])) + send_msgs(client, [['ItemFound', (target_item, location, slot)]])) ctx.location_checks[team, slot] |= set(locations) send_new_items(ctx) @@ -522,14 +535,14 @@ class ClientMessageProcessor(CommandProcessor): async def process_client_cmd(ctx: Context, client: Client, cmd, args): if type(cmd) is not str: - await send_msgs(client.socket, [['InvalidCmd']]) + await send_msgs(client, [['InvalidCmd']]) return if cmd == 'Connect': if not args or type(args) is not dict or \ 'password' not in args or type(args['password']) not in [str, type(None)] or \ 'rom' not in args or type(args['rom']) is not list: - await send_msgs(client.socket, [['InvalidArguments', 'Connect']]) + await send_msgs(client, [['InvalidArguments', 'Connect']]) return errors = set() @@ -548,7 +561,7 @@ async def process_client_cmd(ctx: Context, client: Client, cmd, args): client.slot = slot if errors: - await send_msgs(client.socket, [['ConnectionRefused', list(errors)]]) + await send_msgs(client, [['ConnectionRefused', list(errors)]]) else: client.auth = True client.version = args.get('version', Client.version) @@ -559,7 +572,7 @@ async def process_client_cmd(ctx: Context, client: Client, cmd, args): if items: reply.append(['ReceivedItems', (0, tuplize_received_items(items))]) client.send_index = len(items) - await send_msgs(client.socket, reply) + await send_msgs(client, reply) await on_client_joined(ctx, client) if client.auth: @@ -567,22 +580,22 @@ async def process_client_cmd(ctx: Context, client: Client, cmd, args): items = get_received_items(ctx, client.team, client.slot) if items: client.send_index = len(items) - await send_msgs(client.socket, [['ReceivedItems', (0, tuplize_received_items(items))]]) + await send_msgs(client, [['ReceivedItems', (0, tuplize_received_items(items))]]) elif cmd == 'LocationChecks': if type(args) is not list: - await send_msgs(client.socket, [['InvalidArguments', 'LocationChecks']]) + await send_msgs(client, [['InvalidArguments', 'LocationChecks']]) return register_location_checks(ctx, client.team, client.slot, args) elif cmd == 'LocationScouts': if type(args) is not list: - await send_msgs(client.socket, [['InvalidArguments', 'LocationScouts']]) + await send_msgs(client, [['InvalidArguments', 'LocationScouts']]) return locs = [] for location in args: if type(location) is not int or 0 >= location > len(Regions.location_table): - await send_msgs(client.socket, [['InvalidArguments', 'LocationScouts']]) + await send_msgs(client, [['InvalidArguments', 'LocationScouts']]) return loc_name = list(Regions.location_table.keys())[location - 1] target_item, target_player = ctx.locations[(Regions.location_table[loc_name][0], client.slot)] @@ -595,17 +608,17 @@ async def process_client_cmd(ctx: Context, client: Client, cmd, args): locs.append([loc_name, location, target_item, target_player]) # logging.info(f"{client.name} in team {client.team+1} scouted {', '.join([l[0] for l in locs])}") - await send_msgs(client.socket, [['LocationInfo', [l[1:] for l in locs]]]) + await send_msgs(client, [['LocationInfo', [l[1:] for l in locs]]]) elif cmd == 'UpdateTags': if not args or type(args) is not list: - await send_msgs(client.socket, [['InvalidArguments', 'UpdateTags']]) + await send_msgs(client, [['InvalidArguments', 'UpdateTags']]) return client.tags = args if cmd == 'Say': if type(args) is not str or not args.isprintable(): - await send_msgs(client.socket, [['InvalidArguments', 'Say']]) + await send_msgs(client, [['InvalidArguments', 'Say']]) return notify_all(ctx, client.name + ': ' + args)