mirror of
				https://github.com/MarioSpore/Grinch-AP.git
				synced 2025-10-21 20:21:32 -06:00 
			
		
		
		
	
		
			
				
	
	
		
			276 lines
		
	
	
		
			9.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			276 lines
		
	
	
		
			9.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import asyncio
 | |
| import os
 | |
| import os.path
 | |
| import shutil
 | |
| import signal
 | |
| import subprocess
 | |
| import sys
 | |
| import tempfile
 | |
| import time
 | |
| from contextlib import suppress
 | |
| from typing import Any, Dict, List, Optional, Tuple, Union
 | |
| 
 | |
| import aiohttp
 | |
| import portpicker
 | |
| from worlds._sc2common.bot import logger
 | |
| 
 | |
| from . import paths, wsl
 | |
| from .controller import Controller
 | |
| from .paths import Paths
 | |
| from .versions import VERSIONS
 | |
| 
 | |
| 
 | |
| class kill_switch:
 | |
|     _to_kill: List[Any] = []
 | |
| 
 | |
|     @classmethod
 | |
|     def add(cls, value):
 | |
|         logger.debug("kill_switch: Add switch")
 | |
|         cls._to_kill.append(value)
 | |
| 
 | |
|     @classmethod
 | |
|     def kill_all(cls):
 | |
|         logger.info(f"kill_switch: Process cleanup for {len(cls._to_kill)} processes")
 | |
|         for p in cls._to_kill:
 | |
|             # pylint: disable=W0212
 | |
|             p._clean(verbose=False)
 | |
| 
 | |
| 
 | |
| class SC2Process:
 | |
|     """
 | |
|     A class for handling SCII applications.
 | |
| 
 | |
|     :param host: hostname for the url the SCII application will listen to
 | |
|     :param port: the websocket port the SCII application will listen to
 | |
|     :param fullscreen: whether to launch the SCII application in fullscreen or not, defaults to False
 | |
|     :param resolution: (window width, window height) in pixels, defaults to (1024, 768)
 | |
|     :param placement: (x, y) the distances of the SCII app's top left corner from the top left corner of the screen
 | |
|                        e.g. (20, 30) is 20 to the right of the screen's left border, and 30 below the top border
 | |
|     :param render:
 | |
|     :param sc2_version:
 | |
|     :param base_build:
 | |
|     :param data_hash:
 | |
|     """
 | |
| 
 | |
|     def __init__(
 | |
|         self,
 | |
|         host: Optional[str] = None,
 | |
|         port: Optional[int] = None,
 | |
|         fullscreen: bool = False,
 | |
|         resolution: Optional[Union[List[int], Tuple[int, int]]] = None,
 | |
|         placement: Optional[Union[List[int], Tuple[int, int]]] = None,
 | |
|         render: bool = False,
 | |
|         sc2_version: str = None,
 | |
|         base_build: str = None,
 | |
|         data_hash: str = None,
 | |
|     ) -> None:
 | |
|         assert isinstance(host, str) or host is None
 | |
|         assert isinstance(port, int) or port is None
 | |
| 
 | |
|         self._render = render
 | |
|         self._arguments: Dict[str, str] = {"-displayMode": str(int(fullscreen))}
 | |
|         if not fullscreen:
 | |
|             if resolution and len(resolution) == 2:
 | |
|                 self._arguments["-windowwidth"] = str(resolution[0])
 | |
|                 self._arguments["-windowheight"] = str(resolution[1])
 | |
|             if placement and len(placement) == 2:
 | |
|                 self._arguments["-windowx"] = str(placement[0])
 | |
|                 self._arguments["-windowy"] = str(placement[1])
 | |
| 
 | |
|         self._host = host or os.environ.get("SC2CLIENTHOST", "127.0.0.1")
 | |
|         self._serverhost = os.environ.get("SC2SERVERHOST", self._host)
 | |
| 
 | |
|         if port is None:
 | |
|             self._port = portpicker.pick_unused_port()
 | |
|         else:
 | |
|             self._port = port
 | |
|         self._used_portpicker = bool(port is None)
 | |
|         self._tmp_dir = tempfile.mkdtemp(prefix="SC2_")
 | |
|         self._process: subprocess = None
 | |
|         self._session = None
 | |
|         self._ws = None
 | |
|         self._sc2_version = sc2_version
 | |
|         self._base_build = base_build
 | |
|         self._data_hash = data_hash
 | |
| 
 | |
|     async def __aenter__(self) -> Controller:
 | |
|         kill_switch.add(self)
 | |
| 
 | |
|         def signal_handler(*_args):
 | |
|             # unused arguments: signal handling library expects all signal
 | |
|             # callback handlers to accept two positional arguments
 | |
|             kill_switch.kill_all()
 | |
| 
 | |
|         signal.signal(signal.SIGINT, signal_handler)
 | |
| 
 | |
|         try:
 | |
|             self._process = self._launch()
 | |
|             self._ws = await self._connect()
 | |
|         except:
 | |
|             await self._close_connection()
 | |
|             self._clean()
 | |
|             raise
 | |
| 
 | |
|         return Controller(self._ws, self)
 | |
| 
 | |
|     async def __aexit__(self, *args):
 | |
|         logger.exception("async exit")
 | |
|         await self._close_connection()
 | |
|         kill_switch.kill_all()
 | |
|         signal.signal(signal.SIGINT, signal.SIG_DFL)
 | |
| 
 | |
|     @property
 | |
|     def ws_url(self):
 | |
|         return f"ws://{self._host}:{self._port}/sc2api"
 | |
| 
 | |
|     @property
 | |
|     def versions(self):
 | |
|         """Opens the versions.json file which origins from
 | |
|         https://github.com/Blizzard/s2client-proto/blob/master/buildinfo/versions.json"""
 | |
|         return VERSIONS
 | |
| 
 | |
|     def find_data_hash(self, target_sc2_version: str) -> Optional[str]:
 | |
|         """ Returns the data hash from the matching version string. """
 | |
|         version: dict
 | |
|         for version in self.versions:
 | |
|             if version["label"] == target_sc2_version:
 | |
|                 return version["data-hash"]
 | |
|         return None
 | |
| 
 | |
|     def _launch(self):
 | |
|         if self._base_build:
 | |
|             executable = str(paths.latest_executeble(Paths.BASE / "Versions", self._base_build))
 | |
|         else:
 | |
|             executable = str(Paths.EXECUTABLE)
 | |
|         if self._port is None:
 | |
|             self._port = portpicker.pick_unused_port()
 | |
|             self._used_portpicker = True
 | |
|         args = paths.get_runner_args(Paths.CWD) + [
 | |
|             executable,
 | |
|             "-listen",
 | |
|             self._serverhost,
 | |
|             "-port",
 | |
|             str(self._port),
 | |
|             "-dataDir",
 | |
|             str(Paths.BASE),
 | |
|             "-tempDir",
 | |
|             self._tmp_dir,
 | |
|         ]
 | |
|         for arg, value in self._arguments.items():
 | |
|             args.append(arg)
 | |
|             args.append(value)
 | |
|         if self._sc2_version:
 | |
| 
 | |
|             def special_match(strg: str):
 | |
|                 """ Tests if the specified version is in the versions.py dict. """
 | |
|                 for version in self.versions:
 | |
|                     if version["label"] == strg:
 | |
|                         return True
 | |
|                 return False
 | |
| 
 | |
|             valid_version_string = special_match(self._sc2_version)
 | |
|             if valid_version_string:
 | |
|                 self._data_hash = self.find_data_hash(self._sc2_version)
 | |
|                 assert (
 | |
|                     self._data_hash is not None
 | |
|                 ), f"StarCraft 2 Client version ({self._sc2_version}) was not found inside sc2/versions.py file. Please check your spelling or check the versions.py file."
 | |
| 
 | |
|             else:
 | |
|                 logger.warning(
 | |
|                     f'The submitted version string in sc2.rungame() function call (sc2_version="{self._sc2_version}") was not found in versions.py. Running latest version instead.'
 | |
|                 )
 | |
| 
 | |
|         if self._data_hash:
 | |
|             args.extend(["-dataVersion", self._data_hash])
 | |
| 
 | |
|         if self._render:
 | |
|             args.extend(["-eglpath", "libEGL.so"])
 | |
| 
 | |
|         # if logger.getEffectiveLevel() <= logging.DEBUG:
 | |
|         args.append("-verbose")
 | |
| 
 | |
|         sc2_cwd = str(Paths.CWD) if Paths.CWD else None
 | |
| 
 | |
|         if paths.PF in {"WSL1", "WSL2"}:
 | |
|             return wsl.run(args, sc2_cwd)
 | |
| 
 | |
|         return subprocess.Popen(
 | |
|             args,
 | |
|             cwd=sc2_cwd,
 | |
|             # Suppress Wine error messages
 | |
|             stderr=subprocess.DEVNULL
 | |
|             # , env=run_config.env
 | |
|         )
 | |
| 
 | |
|     async def _connect(self):
 | |
|         # How long it waits for SC2 to start (in seconds)
 | |
|         for i in range(180):
 | |
|             if self._process is None:
 | |
|                 # The ._clean() was called, clearing the process
 | |
|                 logger.debug("Process cleanup complete, exit")
 | |
|                 sys.exit()
 | |
| 
 | |
|             await asyncio.sleep(1)
 | |
|             try:
 | |
|                 self._session = aiohttp.ClientSession()
 | |
|                 ws = await self._session.ws_connect(self.ws_url, timeout=120)
 | |
|                 # FIXME fix deprecation warning in for future aiohttp version
 | |
|                 # ws = await self._session.ws_connect(
 | |
|                 #     self.ws_url, timeout=aiohttp.client_ws.ClientWSTimeout(ws_close=120)
 | |
|                 # )
 | |
|                 logger.debug("Websocket connection ready")
 | |
|                 return ws
 | |
|             except aiohttp.client_exceptions.ClientConnectorError:
 | |
|                 await self._session.close()
 | |
|                 if i > 15:
 | |
|                     logger.debug("Connection refused (startup not complete (yet))")
 | |
| 
 | |
|         logger.debug("Websocket connection to SC2 process timed out")
 | |
|         raise TimeoutError("Websocket")
 | |
| 
 | |
|     async def _close_connection(self):
 | |
|         logger.info(f"Closing connection at {self._port}...")
 | |
| 
 | |
|         if self._ws is not None:
 | |
|             await self._ws.close()
 | |
| 
 | |
|         if self._session is not None:
 | |
|             await self._session.close()
 | |
| 
 | |
|     # pylint: disable=R0912
 | |
|     def _clean(self, verbose=True):
 | |
|         if verbose:
 | |
|             logger.info("Cleaning up...")
 | |
| 
 | |
|         if self._process is not None:
 | |
|             if paths.PF in {"WSL1", "WSL2"}:
 | |
|                 if wsl.kill(self._process):
 | |
|                     logger.error("KILLED")
 | |
|             elif self._process.poll() is None:
 | |
|                 for _ in range(3):
 | |
|                     self._process.terminate()
 | |
|                     time.sleep(0.5)
 | |
|                     if not self._process or self._process.poll() is not None:
 | |
|                         break
 | |
|             else:
 | |
|                 self._process.kill()
 | |
|                 self._process.wait()
 | |
|                 logger.error("KILLED")
 | |
|             # Try to kill wineserver on linux
 | |
|             if paths.PF in {"Linux", "WineLinux"}:
 | |
|                 # Command wineserver not detected
 | |
|                 with suppress(FileNotFoundError):
 | |
|                     with subprocess.Popen(["wineserver", "-k"]) as p:
 | |
|                         p.wait()
 | |
| 
 | |
|         if os.path.exists(self._tmp_dir):
 | |
|             shutil.rmtree(self._tmp_dir)
 | |
| 
 | |
|         self._process = None
 | |
|         self._ws = None
 | |
|         if self._used_portpicker and self._port is not None:
 | |
|             portpicker.return_port(self._port)
 | |
|             self._port = None
 | |
|         if verbose:
 | |
|             logger.info("Cleanup complete")
 | 
