234 lines
		
	
	
		
			10 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			234 lines
		
	
	
		
			10 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # pylint: disable=W0212
 | |
| import asyncio
 | |
| import os
 | |
| import platform
 | |
| import subprocess
 | |
| import time
 | |
| import traceback
 | |
| 
 | |
| from aiohttp import WSMsgType, web
 | |
| from worlds._sc2common.bot import logger
 | |
| from s2clientprotocol import sc2api_pb2 as sc_pb
 | |
| 
 | |
| from .controller import Controller
 | |
| from .data import Result, Status
 | |
| from .player import BotProcess
 | |
| 
 | |
| 
 | |
class Proxy:
    """
    Class for handling communication between sc2 and an external bot.
    This "middleman" is needed for enforcing time limits, collecting results, and closing things properly.

    The proxy exposes a websocket endpoint (``/sc2api``) for the bot, forwards every
    binary request to the SC2 controller's websocket, inspects the response, and
    relays it back to the bot.
    """

    def __init__(
        self,
        controller: Controller,
        player: BotProcess,
        proxyport: int,
        game_time_limit: int = None,
        realtime: bool = False,
    ):
        """
        Store the collaborators and initialise the bookkeeping state.

        :param controller: Controller whose websocket (``_ws``) talks to the SC2 process.
        :param player: Description of the external bot process to launch.
        :param proxyport: Port on which the proxy's own websocket server listens.
        :param game_time_limit: Optional limit; when falsy no timeout is enforced.
        :param realtime: Forwarded to ``player.cmd_line`` when the bot is started.
        """
        self.controller = controller
        self.player = player
        self.port = proxyport
        # The limit is converted to game loops because it is compared against
        # ``observation.game_loop`` in parse_response.
        # NOTE(review): 22.4 is presumably game loops per second - confirm.
        self.timeout_loop = game_time_limit * 22.4 if game_time_limit else None
        self.realtime = realtime
        logger.debug(
            f"Proxy Inited with ctrl {controller}({controller._process._port}), player {player}, proxyport {proxyport}, lim {game_time_limit}"
        )

        # Either None, a Result, or a dict mapping player_id -> Result.
        self.result = None
        # Filled in from the first join_game response seen by parse_response.
        self.player_id: int = None
        # Set to True once proxy_handler has finished (successfully or not).
        self.done = False

    async def parse_request(self, msg):
        """
        Decode a request coming from the bot, adjust it, and forward it to SC2.

        * A ``quit`` request is replaced by a ``leave_game`` request.
        * On ``leave_game`` while the controller is in game, the player is
          recorded as defeated (surrender).
        * A ``join_game`` request without a player name gets ``self.player.name``.

        :param msg: Websocket message whose ``data`` holds a serialized ``sc_pb.Request``.
        """
        request = sc_pb.Request()
        request.ParseFromString(msg.data)
        if request.HasField("quit"):
            # Never forward "quit"; downgrade it to leaving the game.
            request = sc_pb.Request(leave_game=sc_pb.RequestLeaveGame())
        if request.HasField("leave_game"):
            if self.controller._status == Status.in_game:
                logger.info(f"Proxy: player {self.player.name}({self.player_id}) surrenders")
                self.result = {self.player_id: Result.Defeat}
            elif self.controller._status == Status.ended:
                # Reads one message from SC2 and discards it.
                # NOTE(review): presumably drains a pending response - confirm.
                await self.get_response()
        elif request.HasField("join_game") and not request.join_game.HasField("player_name"):
            request.join_game.player_name = self.player.name
        await self.controller._ws.send_bytes(request.SerializeToString())

    # TODO Catching too general exception Exception (broad-except)
    # pylint: disable=W0703
    async def get_response(self):
        """
        Receive one binary message from the SC2 websocket.

        All exceptions are logged and swallowed. If the first receive is cancelled,
        a second receive is attempted once.

        :return: The received bytes, or None when nothing could be received.
        """
        response_bytes = None
        try:
            response_bytes = await self.controller._ws.receive_bytes()
        except TypeError as e:
            logger.exception("Cannot receive: SC2 Connection already closed.")
            tb = traceback.format_exc()
            logger.error(f"Exception {e}: {tb}")
        except asyncio.CancelledError:
            logger.info(f"Proxy({self.player.name}), caught receive from sc2")
            try:
                retried = await self.controller._ws.receive_bytes()
                if response_bytes is None:
                    response_bytes = retried
            except (asyncio.CancelledError, asyncio.TimeoutError, Exception) as e:
                logger.exception(f"Exception {e}")
        except Exception as e:
            logger.exception(f"Caught unknown exception: {e}")
        return response_bytes

    async def parse_response(self, response_bytes):
        """
        Decode a response from SC2 and update the proxy's state from it.

        Tracks controller status changes, learns the player id from the first
        ``join_game`` response, and records the game result, either from the
        observation's ``player_result`` or as a tie once the game loop exceeds
        ``self.timeout_loop``.

        :param response_bytes: Serialized ``sc_pb.Response``.
        :return: The parsed ``sc_pb.Response``.
        """
        response = sc_pb.Response()
        response.ParseFromString(response_bytes)

        if not response.HasField("status"):
            # Fixed: this message lacked the f-prefix, so "{response}" was logged literally.
            logger.critical(f"Proxy: RESPONSE HAS NO STATUS {response}")
        else:
            new_status = Status(response.status)
            if new_status != self.controller._status:
                logger.info(f"Controller({self.player.name}): {self.controller._status}->{new_status}")
                self.controller._status = new_status

        if self.player_id is None:
            if response.HasField("join_game"):
                self.player_id = response.join_game.player_id
                logger.info(f"Proxy({self.player.name}): got join_game for {self.player_id}")

        if self.result is None:
            if response.HasField("observation"):
                obs: sc_pb.ResponseObservation = response.observation
                if obs.player_result:
                    self.result = {pr.player_id: Result(pr.result) for pr in obs.player_result}
                elif (
                    self.timeout_loop and obs.HasField("observation") and obs.observation.game_loop > self.timeout_loop
                ):
                    # Time limit exceeded: both players (ids 1 and 2) get a tie.
                    self.result = {i: Result.Tie for i in range(1, 3)}
                    logger.info(f"Proxy({self.player.name}) timing out")
                    act = [sc_pb.Action(action_chat=sc_pb.ActionChat(message="Proxy: Timing out"))]
                    await self.controller._execute(action=sc_pb.RequestAction(actions=act))
        return response

    async def get_result(self):
        """
        Try to fetch the game result directly from SC2.

        Pings the controller and, depending on the returned status, requests an
        observation; if it carries ``player_result`` entries they are stored in
        ``self.result``. Any exception is logged and swallowed.
        """
        try:
            res = await self.controller.ping()
            # NOTE(review): res.status is compared with Status members directly,
            # whereas parse_response wraps the raw value in Status(...) - confirm
            # that the types returned by ping() actually match.
            if res.status in {Status.in_game, Status.in_replay, Status.ended}:
                res = await self.controller._execute(observation=sc_pb.RequestObservation())
                if res.HasField("observation") and res.observation.player_result:
                    self.result = {pr.player_id: Result(pr.result) for pr in res.observation.player_result}
        # pylint: disable=W0703
        # TODO Catching too general exception Exception (broad-except)
        except Exception as e:
            logger.exception(f"Caught unknown exception: {e}")

    async def proxy_handler(self, request):
        """
        Websocket handler serving the bot's connection on ``/sc2api``.

        For each binary message: forward it to SC2, wait for the response, inspect
        it, and send it back to the bot. On exit the game is left (if still running),
        the bot websocket is closed, and ``self.done`` is set.

        :param request: Incoming aiohttp request to upgrade to a websocket.
        :return: The ``web.WebSocketResponse`` used for the bot connection.
        """
        bot_ws = web.WebSocketResponse(receive_timeout=30)
        await bot_ws.prepare(request)
        try:
            async for msg in bot_ws:
                if msg.data is None:
                    raise TypeError(f"data is None, {msg}")
                if msg.data and msg.type == WSMsgType.BINARY:

                    await self.parse_request(msg)

                    response_bytes = await self.get_response()
                    if response_bytes is None:
                        raise ConnectionError("Could not get response_bytes")

                    new_response = await self.parse_response(response_bytes)
                    await bot_ws.send_bytes(new_response.SerializeToString())

                elif msg.type == WSMsgType.CLOSED:
                    logger.error("Client shutdown")
                else:
                    logger.error("Incorrect message type")
        # pylint: disable=W0703
        # TODO Catching too general exception Exception (broad-except)
        except Exception as e:
            logger.exception(f"Caught unknown exception: {e}")
            ignored_errors = {ConnectionError, asyncio.CancelledError}
            if not any(isinstance(e, E) for E in ignored_errors):
                tb = traceback.format_exc()
                logger.info(f"Proxy({self.player.name}): Caught {e} traceback: {tb}")
        finally:
            try:
                # Make sure SC2 is not left inside a running game / replay.
                if self.controller._status in {Status.in_game, Status.in_replay}:
                    await self.controller._execute(leave_game=sc_pb.RequestLeaveGame())
                await bot_ws.close()
            # pylint: disable=W0703
            # TODO Catching too general exception Exception (broad-except)
            except Exception as e:
                logger.exception(f"Caught unknown exception during surrender: {e}")
            self.done = True
        return bot_ws

    # pylint: disable=R0912
    async def play_with_proxy(self, startport):
        """
        Run a complete game for this player through the proxy.

        Starts the websocket server, launches the bot subprocess, polls every five
        seconds until a result is known or the bot / SC2 / handler has stopped,
        then terminates the bot and shuts the server down.

        :param startport: Passed through to ``player.cmd_line``.
        :return: This player's entry of the result dict when ``self.result`` is a
            dict, otherwise ``self.result`` itself (possibly None).
        """
        logger.info(f"Proxy({self.port}): Starting app")
        app = web.Application()
        app.router.add_route("GET", "/sc2api", self.proxy_handler)
        apprunner = web.AppRunner(app, access_log=None)
        await apprunner.setup()
        appsite = web.TCPSite(apprunner, self.controller._process._host, self.port)
        await appsite.start()

        subproc_args = {"cwd": str(self.player.path), "stderr": subprocess.STDOUT}
        # Put the bot into its own process group on the platforms handled here.
        if platform.system() == "Linux":
            subproc_args["preexec_fn"] = os.setpgrp
        elif platform.system() == "Windows":
            subproc_args["creationflags"] = subprocess.CREATE_NEW_PROCESS_GROUP

        player_command_line = self.player.cmd_line(self.port, startport, self.controller._process._host, self.realtime)
        logger.info(f"Starting bot with command: {' '.join(player_command_line)}")
        if self.player.stdout is None:
            bot_process = subprocess.Popen(player_command_line, stdout=subprocess.DEVNULL, **subproc_args)
        else:
            with open(self.player.stdout, "w+") as out:
                bot_process = subprocess.Popen(player_command_line, stdout=out, **subproc_args)

        while self.result is None:
            bot_alive = bot_process and bot_process.poll() is None
            sc2_alive = self.controller.running
            if self.done or not (bot_alive and sc2_alive):
                logger.info(
                    f"Proxy({self.port}): {self.player.name} died, "
                    f"bot{(not bot_alive) * ' not'} alive, sc2{(not sc2_alive) * ' not'} alive"
                )
                # Maybe its still possible to retrieve a result
                # NOTE(review): get_response() only reads bytes and its return value
                # is dropped here, so self.result is not updated by this call -
                # confirm whether get_result() was intended.
                if sc2_alive and not self.done:
                    await self.get_response()
                logger.info(f"Proxy({self.port}): breaking, result {self.result}")
                break
            await asyncio.sleep(5)

        # cleanup
        logger.info(f"({self.port}): cleaning up {self.player !r}")
        for _i in range(3):
            if isinstance(bot_process, subprocess.Popen):
                # NOTE(review): stdout is DEVNULL or a file above, so Popen.stdout
                # is presumably always None and this dump never runs - confirm.
                if bot_process.stdout and not bot_process.stdout.closed:  # should not run anymore
                    logger.info(f"==================output for player {self.player.name}")
                    for line in bot_process.stdout.readlines():
                        logger.opt(raw=True).info(line.decode("utf-8"))
                    bot_process.stdout.close()
                    logger.info("==================")
                bot_process.terminate()
                bot_process.wait()
            # Fixed: was a blocking time.sleep(), which stalled the event loop
            # (and therefore every other coroutine) from inside this coroutine.
            await asyncio.sleep(0.5)
            if not bot_process or bot_process.poll() is not None:
                break
        else:
            # All attempts used up without the process exiting: try one last time.
            bot_process.terminate()
            bot_process.wait()
        try:
            await apprunner.cleanup()
        # pylint: disable=W0703
        # TODO Catching too general exception Exception (broad-except)
        except Exception as e:
            logger.exception(f"Caught unknown exception during cleaning: {e}")
        if isinstance(self.result, dict):
            # Guarantees the lookup below succeeds (yielding None) when player_id
            # was never learned and is still None.
            self.result[None] = None
            return self.result[self.player_id]
        return self.result
 | 
