mirror of
				https://github.com/MarioSpore/Grinch-AP.git
				synced 2025-10-21 20:21:32 -06:00 
			
		
		
		
	
		
			
	
	
		
			234 lines
		
	
	
		
			10 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
		
		
			
		
	
	
			234 lines
		
	
	
		
			10 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
|   | # pylint: disable=W0212 | ||
|  | import asyncio | ||
|  | import os | ||
|  | import platform | ||
|  | import subprocess | ||
|  | import time | ||
|  | import traceback | ||
|  | 
 | ||
|  | from aiohttp import WSMsgType, web | ||
|  | from worlds._sc2common.bot import logger | ||
|  | from s2clientprotocol import sc2api_pb2 as sc_pb | ||
|  | 
 | ||
|  | from .controller import Controller | ||
|  | from .data import Result, Status | ||
|  | from .player import BotProcess | ||
|  | 
 | ||
|  | 
 | ||
class Proxy:
    """
    Class for handling communication between sc2 and an external bot.
    This "middleman" is needed for enforcing time limits, collecting results, and closing things properly.
    """

    # Factor converting ``game_time_limit`` into a game-loop threshold
    # (SC2 advances 22.4 game loops per second at "faster" game speed).
    _GAME_LOOPS_PER_SECOND = 22.4
    # Number of terminate() attempts before the bot process is force-killed.
    _TERMINATE_ATTEMPTS = 3
    # Seconds to wait for the bot to exit after each terminate().
    _TERMINATE_TIMEOUT = 5

    def __init__(
        self,
        controller: Controller,
        player: BotProcess,
        proxyport: int,
        game_time_limit: int = None,
        realtime: bool = False,
    ):
        """
        :param controller: Controller whose websocket (``controller._ws``) requests are forwarded to.
        :param player: The external bot process that connects to this proxy.
        :param proxyport: Port the proxy's websocket server listens on.
        :param game_time_limit: Optional limit (in game seconds); once the observed game loop exceeds
                                it, the game is recorded as a tie.
        :param realtime: Forwarded to the bot's command line.
        """
        self.controller = controller
        self.player = player
        self.port = proxyport
        self.timeout_loop = game_time_limit * self._GAME_LOOPS_PER_SECOND if game_time_limit else None
        self.realtime = realtime
        logger.debug(
            f"Proxy Inited with ctrl {controller}({controller._process._port}), player {player}, proxyport {proxyport}, lim {game_time_limit}"
        )

        # Either None (unknown yet) or a {player_id: Result} dict.
        self.result = None
        # Set from the first join_game response seen in parse_response.
        self.player_id: int = None
        # Set once proxy_handler has finished and cleaned up.
        self.done = False

    async def parse_request(self, msg):
        """
        Decode a request from the bot, adjust it, and forward it to SC2.

        - ``quit`` is rewritten into ``leave_game`` (presumably so the bot cannot close the SC2
          instance the host owns — confirm).
        - ``leave_game`` while in game is recorded as a defeat (surrender) for this player.
        - ``join_game`` without a player name gets the bot's name filled in.

        :param msg: aiohttp websocket message whose ``data`` is a serialized ``sc_pb.Request``.
        """
        request = sc_pb.Request()
        request.ParseFromString(msg.data)
        if request.HasField("quit"):
            request = sc_pb.Request(leave_game=sc_pb.RequestLeaveGame())
        if request.HasField("leave_game"):
            if self.controller._status == Status.in_game:
                logger.info(f"Proxy: player {self.player.name}({self.player_id}) surrenders")
                self.result = {self.player_id: Result.Defeat}
            elif self.controller._status == Status.ended:
                # Reads (and discards) one message from SC2 before forwarding leave_game.
                await self.get_response()
        elif request.HasField("join_game") and not request.join_game.HasField("player_name"):
            request.join_game.player_name = self.player.name
        await self.controller._ws.send_bytes(request.SerializeToString())

    # TODO Catching too general exception Exception (broad-except)
    # pylint: disable=W0703
    async def get_response(self):
        """
        Receive the next binary message from the SC2 websocket.

        :return: The raw response bytes, or ``None`` if nothing could be received (errors are logged).
        """
        response_bytes = None
        try:
            response_bytes = await self.controller._ws.receive_bytes()
        except TypeError as e:
            logger.exception("Cannot receive: SC2 Connection already closed.")
            tb = traceback.format_exc()
            logger.error(f"Exception {e}: {tb}")
        except asyncio.CancelledError:
            # NOTE(review): the cancellation is swallowed and a second receive is attempted, so the
            # caller's cancellation does not propagate out of this coroutine.
            logger.info(f"Proxy({self.player.name}), caught receive from sc2")
            try:
                response_bytes = await self.controller._ws.receive_bytes()
            except (asyncio.CancelledError, asyncio.TimeoutError, Exception) as e:
                logger.exception(f"Exception {e}")
        except Exception as e:
            logger.exception(f"Caught unknown exception: {e}")
        return response_bytes

    async def parse_response(self, response_bytes):
        """
        Decode a response from SC2 and update proxy state from it.

        Tracks the controller status, learns this player's id from ``join_game``, records the
        result from ``observation.player_result``, and declares a tie once the game loop exceeds
        ``self.timeout_loop``.

        :param response_bytes: Serialized ``sc_pb.Response``.
        :return: The parsed ``sc_pb.Response`` (to be forwarded to the bot).
        """
        response = sc_pb.Response()
        response.ParseFromString(response_bytes)

        if not response.HasField("status"):
            logger.critical(f"Proxy: RESPONSE HAS NO STATUS {response}")
        else:
            new_status = Status(response.status)
            if new_status != self.controller._status:
                logger.info(f"Controller({self.player.name}): {self.controller._status}->{new_status}")
                self.controller._status = new_status

        if self.player_id is None:
            if response.HasField("join_game"):
                self.player_id = response.join_game.player_id
                logger.info(f"Proxy({self.player.name}): got join_game for {self.player_id}")

        if self.result is None:
            if response.HasField("observation"):
                obs: sc_pb.ResponseObservation = response.observation
                if obs.player_result:
                    self.result = {pr.player_id: Result(pr.result) for pr in obs.player_result}
                elif (
                    self.timeout_loop and obs.HasField("observation") and obs.observation.game_loop > self.timeout_loop
                ):
                    # Hard-coded for a two-player game: both player ids 1 and 2 get a tie.
                    self.result = {i: Result.Tie for i in range(1, 3)}
                    logger.info(f"Proxy({self.player.name}) timing out")
                    act = [sc_pb.Action(action_chat=sc_pb.ActionChat(message="Proxy: Timing out"))]
                    await self.controller._execute(action=sc_pb.RequestAction(actions=act))
        return response

    async def get_result(self):
        """
        Best-effort query of the final player results directly from SC2.

        Pings SC2 and, if it is in game, in a replay, or the game has ended, requests an observation;
        when that observation carries ``player_result`` it is stored in ``self.result``.
        All errors are logged and swallowed.
        """
        try:
            res = await self.controller.ping()
            if res.status in {Status.in_game, Status.in_replay, Status.ended}:
                res = await self.controller._execute(observation=sc_pb.RequestObservation())
                if res.HasField("observation") and res.observation.player_result:
                    self.result = {pr.player_id: Result(pr.result) for pr in res.observation.player_result}
        # pylint: disable=W0703
        # TODO Catching too general exception Exception (broad-except)
        except Exception as e:
            logger.exception(f"Caught unknown exception: {e}")

    async def proxy_handler(self, request):
        """
        aiohttp handler for the bot's ``/sc2api`` websocket connection.

        Each binary message from the bot is forwarded to SC2 and SC2's response is sent back. When
        the connection ends (or fails), the controller leaves the game if still in one, the bot
        websocket is closed, and ``self.done`` is set.

        :param request: The incoming aiohttp request.
        :return: The prepared ``web.WebSocketResponse``.
        """
        bot_ws = web.WebSocketResponse(receive_timeout=30)
        await bot_ws.prepare(request)
        try:
            async for msg in bot_ws:
                if msg.data is None:
                    raise TypeError(f"data is None, {msg}")
                if msg.data and msg.type == WSMsgType.BINARY:

                    await self.parse_request(msg)

                    response_bytes = await self.get_response()
                    if response_bytes is None:
                        raise ConnectionError("Could not get response_bytes")

                    new_response = await self.parse_response(response_bytes)
                    await bot_ws.send_bytes(new_response.SerializeToString())

                elif msg.type == WSMsgType.CLOSED:
                    logger.error("Client shutdown")
                else:
                    logger.error("Incorrect message type")
        # pylint: disable=W0703
        # TODO Catching too general exception Exception (broad-except)
        except Exception as e:
            logger.exception(f"Caught unknown exception: {e}")
            if not isinstance(e, (ConnectionError, asyncio.CancelledError)):
                tb = traceback.format_exc()
                logger.info(f"Proxy({self.player.name}): Caught {e} traceback: {tb}")
        finally:
            try:
                if self.controller._status in {Status.in_game, Status.in_replay}:
                    await self.controller._execute(leave_game=sc_pb.RequestLeaveGame())
                await bot_ws.close()
            # pylint: disable=W0703
            # TODO Catching too general exception Exception (broad-except)
            except Exception as e:
                logger.exception(f"Caught unknown exception during surrender: {e}")
            self.done = True
        return bot_ws

    def _spawn_bot_process(self, startport) -> subprocess.Popen:
        """Launch the bot executable with a command line pointing it at this proxy."""
        subproc_args = {"cwd": str(self.player.path), "stderr": subprocess.STDOUT}
        # Start the bot in its own process group (presumably so console signals aimed at this
        # process do not hit the bot directly — confirm).
        system = platform.system()
        if system == "Linux":
            subproc_args["preexec_fn"] = os.setpgrp
        elif system == "Windows":
            subproc_args["creationflags"] = subprocess.CREATE_NEW_PROCESS_GROUP

        player_command_line = self.player.cmd_line(self.port, startport, self.controller._process._host, self.realtime)
        logger.info(f"Starting bot with command: {' '.join(player_command_line)}")
        if self.player.stdout is None:
            return subprocess.Popen(player_command_line, stdout=subprocess.DEVNULL, **subproc_args)
        # The child keeps its own copy of the file descriptor, so closing ours here is fine.
        with open(self.player.stdout, "w+") as out:
            return subprocess.Popen(player_command_line, stdout=out, **subproc_args)

    async def _stop_bot_process(self, bot_process: subprocess.Popen):
        """Terminate the bot process, escalating to kill() if it keeps ignoring terminate()."""
        for _attempt in range(self._TERMINATE_ATTEMPTS):
            bot_process.terminate()
            try:
                # Bounded wait: a bot that ignores terminate() must not hang the proxy forever.
                bot_process.wait(timeout=self._TERMINATE_TIMEOUT)
            except subprocess.TimeoutExpired:
                logger.warning(f"Proxy({self.port}): {self.player.name} did not exit after terminate()")
            # Non-blocking pause so other tasks on the event loop keep running.
            await asyncio.sleep(0.5)
            if bot_process.poll() is not None:
                break
        else:
            bot_process.kill()
            bot_process.wait()

    # pylint: disable=R0912
    async def play_with_proxy(self, startport):
        """
        Run the proxy for one game and return this player's result.

        Starts the websocket server the bot connects to, launches the bot process, then checks every
        5 seconds until a result is known, the handler has finished, or the bot or SC2 died.
        Afterwards the bot process is stopped and the web app is cleaned up.

        :param startport: Forwarded to ``BotProcess.cmd_line`` when building the bot's command line.
        :return: ``self.result[self.player_id]`` when the result is a per-player dict (``None`` if
                 this proxy never learned its player id); otherwise ``self.result`` unchanged.
        """
        logger.info(f"Proxy({self.port}): Starting app")
        app = web.Application()
        app.router.add_route("GET", "/sc2api", self.proxy_handler)
        apprunner = web.AppRunner(app, access_log=None)
        await apprunner.setup()
        appsite = web.TCPSite(apprunner, self.controller._process._host, self.port)
        await appsite.start()

        bot_process = self._spawn_bot_process(startport)

        while self.result is None:
            bot_alive = bot_process.poll() is None
            sc2_alive = self.controller.running
            if self.done or not (bot_alive and sc2_alive):
                logger.info(
                    f"Proxy({self.port}): {self.player.name} died, "
                    f"bot{(not bot_alive) * ' not'} alive, sc2{(not sc2_alive) * ' not'} alive"
                )
                # SC2 is still up, so it may still be able to report the final player results.
                if sc2_alive and not self.done:
                    await self.get_result()
                logger.info(f"Proxy({self.port}): breaking, result {self.result}")
                break
            await asyncio.sleep(5)

        # cleanup
        logger.info(f"({self.port}): cleaning up {self.player !r}")
        await self._stop_bot_process(bot_process)
        try:
            await apprunner.cleanup()
        # pylint: disable=W0703
        # TODO Catching too general exception Exception (broad-except)
        except Exception as e:
            logger.exception(f"Caught unknown exception during cleaning: {e}")
        if isinstance(self.result, dict):
            if self.player_id is None:
                # Never received join_game, so there is no entry for this player.
                return None
            return self.result[self.player_id]
        return self.result