234 lines
		
	
	
		
			10 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			234 lines
		
	
	
		
			10 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
# pylint: disable=W0212
 | 
						|
import asyncio
 | 
						|
import os
 | 
						|
import platform
 | 
						|
import subprocess
 | 
						|
import time
 | 
						|
import traceback
 | 
						|
 | 
						|
from aiohttp import WSMsgType, web
 | 
						|
from worlds._sc2common.bot import logger
 | 
						|
from s2clientprotocol import sc2api_pb2 as sc_pb
 | 
						|
 | 
						|
from .controller import Controller
 | 
						|
from .data import Result, Status
 | 
						|
from .player import BotProcess
 | 
						|
 | 
						|
 | 
						|
class Proxy:
    """
    Class for handling communication between sc2 and an external bot.
    This "middleman" is needed for enforcing time limits, collecting results, and closing things properly.
    """

    # Game loops per in-game second, used to turn ``game_time_limit`` (seconds) into a loop count.
    # NOTE(review): 22.4 is the loop rate at the "faster" game speed — confirm no other speed is used.
    _GAME_LOOPS_PER_SECOND = 22.4

    def __init__(
        self,
        controller: Controller,
        player: BotProcess,
        proxyport: int,
        game_time_limit: int = None,
        realtime: bool = False,
    ):
        """
        :param controller: controller connected to the SC2 instance this proxy forwards to.
        :param player: the external bot process description.
        :param proxyport: port the proxy's websocket server listens on; the bot connects here.
        :param game_time_limit: optional limit in game seconds; the game is declared a tie once exceeded.
        :param realtime: passed through to the bot's command line.
        """
        self.controller = controller
        self.player = player
        self.port = proxyport
        self.timeout_loop = game_time_limit * self._GAME_LOOPS_PER_SECOND if game_time_limit else None
        self.realtime = realtime
        logger.debug(
            f"Proxy Inited with ctrl {controller}({controller._process._port}), player {player}, proxyport {proxyport}, lim {game_time_limit}"
        )

        # Mapping player_id -> Result once known, otherwise None.
        self.result = None
        # Filled in from the join_game response.
        self.player_id: int = None
        # Set when the bot's websocket handler has finished and cleaned up.
        self.done = False

    async def parse_request(self, msg):
        """
        Decode one request from the bot, adjust it where needed, and forward it to SC2.

        - A ``quit`` request is rewritten into ``leave_game``.
        - ``leave_game`` while in game records a surrender (Defeat for this player).
        - ``join_game`` without a player name gets the bot's name filled in.
        """
        request = sc_pb.Request()
        request.ParseFromString(msg.data)
        if request.HasField("quit"):
            # Never let a bot shut down the shared SC2 instance; treat it as leaving the game.
            request = sc_pb.Request(leave_game=sc_pb.RequestLeaveGame())
        if request.HasField("leave_game"):
            if self.controller._status == Status.in_game:
                logger.info(f"Proxy: player {self.player.name}({self.player_id}) surrenders")
                self.result = {self.player_id: Result.Defeat}
            elif self.controller._status == Status.ended:
                # NOTE(review): reads one response from SC2 before forwarding; presumably drains
                # a pending reply after the game ended — confirm intent.
                await self.get_response()
        elif request.HasField("join_game") and not request.join_game.HasField("player_name"):
            request.join_game.player_name = self.player.name
        await self.controller._ws.send_bytes(request.SerializeToString())

    # TODO Catching too general exception Exception (broad-except)
    # pylint: disable=W0703
    async def get_response(self):
        """
        Receive one raw response from the SC2 websocket.

        :return: the response bytes, or None if nothing could be received (errors are logged).
        """
        response_bytes = None
        try:
            response_bytes = await self.controller._ws.receive_bytes()
        except TypeError as e:
            logger.exception("Cannot receive: SC2 Connection already closed.")
            tb = traceback.format_exc()
            logger.error(f"Exception {e}: {tb}")
        except asyncio.CancelledError:
            # The cancellation is deliberately swallowed: one more receive is attempted so the
            # SC2 connection is not left with an unread response.
            logger.info(f"Proxy({self.player.name}), caught receive from sc2")
            try:
                response_bytes = await self.controller._ws.receive_bytes()
            except (asyncio.CancelledError, asyncio.TimeoutError, Exception) as e:
                logger.exception(f"Exception {e}")
        except Exception as e:
            logger.exception(f"Caught unknown exception: {e}")
        return response_bytes

    async def parse_response(self, response_bytes):
        """
        Decode an SC2 response and update proxy/controller state from it.

        Tracks the controller status, learns this bot's player id from ``join_game``, and records the
        game result from an observation. Once ``timeout_loop`` is exceeded it declares a tie.

        :return: the parsed ``sc_pb.Response``, to be sent back to the bot.
        """
        response = sc_pb.Response()
        response.ParseFromString(response_bytes)

        if not response.HasField("status"):
            logger.critical(f"Proxy: RESPONSE HAS NO STATUS {response}")
        else:
            new_status = Status(response.status)
            if new_status != self.controller._status:
                logger.info(f"Controller({self.player.name}): {self.controller._status}->{new_status}")
                self.controller._status = new_status

        if self.player_id is None and response.HasField("join_game"):
            self.player_id = response.join_game.player_id
            logger.info(f"Proxy({self.player.name}): got join_game for {self.player_id}")

        if self.result is None and response.HasField("observation"):
            obs: sc_pb.ResponseObservation = response.observation
            if obs.player_result:
                self.result = {pr.player_id: Result(pr.result) for pr in obs.player_result}
            elif self.timeout_loop and obs.HasField("observation") and obs.observation.game_loop > self.timeout_loop:
                # Hard-coded for a two-player game: both players get a tie.
                self.result = {i: Result.Tie for i in range(1, 3)}
                logger.info(f"Proxy({self.player.name}) timing out")
                act = [sc_pb.Action(action_chat=sc_pb.ActionChat(message="Proxy: Timing out"))]
                await self.controller._execute(action=sc_pb.RequestAction(actions=act))
        return response

    async def get_result(self):
        """Try to fetch the game result directly from SC2 and store it in ``self.result``."""
        try:
            res = await self.controller.ping()
            # The raw protobuf status is converted to Status before comparing, as in parse_response.
            # Comparing the raw value against enum members would never match unless Status is an IntEnum.
            if Status(res.status) in {Status.in_game, Status.in_replay, Status.ended}:
                res = await self.controller._execute(observation=sc_pb.RequestObservation())
                if res.HasField("observation") and res.observation.player_result:
                    self.result = {pr.player_id: Result(pr.result) for pr in res.observation.player_result}
        # pylint: disable=W0703
        # TODO Catching too general exception Exception (broad-except)
        except Exception as e:
            logger.exception(f"Caught unknown exception: {e}")

    async def proxy_handler(self, request):
        """
        aiohttp websocket handler for the bot's connection.

        Each binary message from the bot is forwarded to SC2, and SC2's response is sent back.
        On exit it leaves the game if one is still running, closes the bot socket and sets ``done``.
        """
        bot_ws = web.WebSocketResponse(receive_timeout=30)
        await bot_ws.prepare(request)
        try:
            async for msg in bot_ws:
                if msg.data is None:
                    raise TypeError(f"data is None, {msg}")
                if msg.data and msg.type == WSMsgType.BINARY:
                    await self.parse_request(msg)

                    response_bytes = await self.get_response()
                    if response_bytes is None:
                        raise ConnectionError("Could not get response_bytes")

                    new_response = await self.parse_response(response_bytes)
                    await bot_ws.send_bytes(new_response.SerializeToString())
                elif msg.type == WSMsgType.CLOSED:
                    logger.error("Client shutdown")
                else:
                    logger.error("Incorrect message type")
        # pylint: disable=W0703
        # TODO Catching too general exception Exception (broad-except)
        except Exception as e:
            logger.exception(f"Caught unknown exception: {e}")
            # NOTE: since Python 3.8 asyncio.CancelledError derives from BaseException, so it never
            # reaches this handler; it is listed here for older interpreters.
            if not isinstance(e, (ConnectionError, asyncio.CancelledError)):
                tb = traceback.format_exc()
                logger.info(f"Proxy({self.player.name}): Caught {e} traceback: {tb}")
        finally:
            try:
                if self.controller._status in {Status.in_game, Status.in_replay}:
                    await self.controller._execute(leave_game=sc_pb.RequestLeaveGame())
                await bot_ws.close()
            # pylint: disable=W0703
            # TODO Catching too general exception Exception (broad-except)
            except Exception as e:
                logger.exception(f"Caught unknown exception during surrender: {e}")
            self.done = True
        return bot_ws

    def _launch_bot(self, startport):
        """Start the bot subprocess pointed at this proxy and return its ``subprocess.Popen`` handle."""
        subproc_args = {"cwd": str(self.player.path), "stderr": subprocess.STDOUT}
        if platform.system() == "Linux":
            # Put the bot in its own process group (os.setpgrp), separate from ours.
            subproc_args["preexec_fn"] = os.setpgrp
        elif platform.system() == "Windows":
            subproc_args["creationflags"] = subprocess.CREATE_NEW_PROCESS_GROUP

        player_command_line = self.player.cmd_line(self.port, startport, self.controller._process._host, self.realtime)
        logger.info(f"Starting bot with command: {' '.join(player_command_line)}")
        if self.player.stdout is None:
            return subprocess.Popen(player_command_line, stdout=subprocess.DEVNULL, **subproc_args)
        # The child inherits its own copy of the descriptor, so closing our handle after Popen is fine.
        with open(self.player.stdout, "w+") as out:
            return subprocess.Popen(player_command_line, stdout=out, **subproc_args)

    async def _stop_bot(self, bot_process):
        """Terminate and reap the bot process, escalating to kill() if terminate() is ignored."""
        # Popen.stdout is only set when stdout=PIPE was used; _launch_bot never does that,
        # so this dump is defensive.
        if bot_process.stdout and not bot_process.stdout.closed:
            logger.info(f"==================output for player {self.player.name}")
            for line in bot_process.stdout.readlines():
                logger.opt(raw=True).info(line.decode("utf-8"))
            bot_process.stdout.close()
            logger.info("==================")

        for _attempt in range(3):
            if bot_process.poll() is not None:
                return
            bot_process.terminate()
            # Poll with asyncio.sleep instead of blocking in wait()/time.sleep(),
            # so the event loop keeps running while the bot shuts down.
            for _tick in range(10):
                await asyncio.sleep(0.5)
                if bot_process.poll() is not None:
                    return
        # The bot ignored every terminate() request; force it down instead of hanging forever.
        logger.info(f"Proxy({self.port}): bot {self.player.name} ignored terminate, killing it")
        bot_process.kill()
        bot_process.wait()

    async def play_with_proxy(self, startport):
        """
        Serve the proxy, run the bot against it, and wait until the game has a result.

        The websocket app is served on the controller's host at ``self.port``. The method polls every
        5 seconds until a result is known, the handler is done, or the bot/SC2 died. The bot process
        and the web app are always cleaned up, even if something raises.

        :param startport: forwarded to the bot's command line.
        :return: this player's ``Result`` when results are known as a dict (None if this player is
            missing from it); otherwise ``self.result`` as-is (None when no result was obtained).
        """
        logger.info(f"Proxy({self.port}): Starting app")
        app = web.Application()
        app.router.add_route("GET", "/sc2api", self.proxy_handler)
        apprunner = web.AppRunner(app, access_log=None)
        await apprunner.setup()
        try:
            appsite = web.TCPSite(apprunner, self.controller._process._host, self.port)
            await appsite.start()

            bot_process = self._launch_bot(startport)
            try:
                while self.result is None:
                    bot_alive = bot_process.poll() is None
                    sc2_alive = self.controller.running
                    if self.done or not (bot_alive and sc2_alive):
                        logger.info(
                            f"Proxy({self.port}): {self.player.name} died, "
                            f"bot{(not bot_alive) * ' not'} alive, sc2{(not sc2_alive) * ' not'} alive"
                        )
                        # Maybe it's still possible to retrieve a result
                        if sc2_alive and not self.done:
                            await self.get_response()
                        logger.info(f"Proxy({self.port}): breaking, result {self.result}")
                        break
                    await asyncio.sleep(5)
            finally:
                logger.info(f"({self.port}): cleaning up {self.player!r}")
                await self._stop_bot(bot_process)
        finally:
            try:
                await apprunner.cleanup()
            # pylint: disable=W0703
            # TODO Catching too general exception Exception (broad-except)
            except Exception as e:
                logger.exception(f"Caught unknown exception during cleaning: {e}")

        if isinstance(self.result, dict):
            # .get covers player_id None or absent without mutating self.result.
            return self.result.get(self.player_id)
        return self.result