276 lines
		
	
	
		
			9.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
		
		
			
		
	
	
			276 lines
		
	
	
		
			9.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
|   | import asyncio | ||
|  | import os | ||
|  | import os.path | ||
|  | import shutil | ||
|  | import signal | ||
|  | import subprocess | ||
|  | import sys | ||
|  | import tempfile | ||
|  | import time | ||
|  | from contextlib import suppress | ||
|  | from typing import Any, Dict, List, Optional, Tuple, Union | ||
|  | 
 | ||
|  | import aiohttp | ||
|  | import portpicker | ||
|  | from worlds._sc2common.bot import logger | ||
|  | 
 | ||
|  | from . import paths, wsl | ||
|  | from .controller import Controller | ||
|  | from .paths import Paths | ||
|  | from .versions import VERSIONS | ||
|  | 
 | ||
|  | 
 | ||
class kill_switch:
    """Process-wide registry of objects (SC2Process instances) whose ``_clean`` must run on exit or SIGINT."""

    # Shared across the whole process on purpose: every registered instance is cleaned by kill_all().
    _to_kill: List[Any] = []

    @classmethod
    def add(cls, value):
        """Register *value* for cleanup; it must provide a ``_clean(verbose=...)`` method."""
        logger.debug("kill_switch: Add switch")
        cls._to_kill.append(value)

    @classmethod
    def kill_all(cls):
        """
        Call ``_clean(verbose=False)`` on every registered object, then forget them.

        The registry is emptied so repeated calls (e.g. one per ``__aexit__``) do not re-clean
        objects that were already handled and the list does not grow forever. A failure while
        cleaning one entry is logged and does not prevent the remaining entries from being cleaned.
        """
        logger.info(f"kill_switch: Process cleanup for {len(cls._to_kill)} processes")
        # Snapshot and clear first so a re-entrant call (SIGINT during cleanup) sees an empty registry.
        pending = list(cls._to_kill)
        cls._to_kill.clear()
        for p in pending:
            try:
                # pylint: disable=W0212
                p._clean(verbose=False)
            except Exception:
                logger.exception("kill_switch: cleanup of a registered process failed")
|  | 
 | ||
|  | 
 | ||
class SC2Process:
    """
    A class for handling SCII applications.

    Used as an async context manager: entering launches the SCII executable, connects to its
    websocket API and returns a ``Controller``; exiting closes the connection, cleans up every
    object registered with ``kill_switch`` and restores the default SIGINT handler.

    :param host: hostname for the url the SCII application will listen to; falls back to the
                 SC2CLIENTHOST environment variable, then "127.0.0.1"
    :param port: the websocket port the SCII application will listen to; when omitted an unused
                 port is picked with portpicker and returned to it on cleanup
    :param fullscreen: whether to launch the SCII application in fullscreen or not, defaults to False
    :param resolution: (window width, window height) in pixels; only applied in windowed mode.
                       When omitted, no window size is passed to the application
    :param placement: (x, y) the distances of the SCII app's top left corner from the top left corner of the screen
                       e.g. (20, 30) is 20 to the right of the screen's left border, and 30 below the top border;
                       only applied in windowed mode
    :param render: when True, "-eglpath libEGL.so" is passed to the application
    :param sc2_version: version label looked up in VERSIONS; its data hash is passed as -dataVersion.
                        Unknown labels only produce a warning
    :param base_build: when set, the executable is chosen via paths.latest_executeble from
                       Paths.BASE / "Versions" instead of Paths.EXECUTABLE
    :param data_hash: passed as -dataVersion; replaced by the looked-up hash when sc2_version is known
    """

    def __init__(
        self,
        host: Optional[str] = None,
        port: Optional[int] = None,
        fullscreen: bool = False,
        resolution: Optional[Union[List[int], Tuple[int, int]]] = None,
        placement: Optional[Union[List[int], Tuple[int, int]]] = None,
        render: bool = False,
        sc2_version: Optional[str] = None,
        base_build: Optional[str] = None,
        data_hash: Optional[str] = None,
    ) -> None:
        assert isinstance(host, str) or host is None
        assert isinstance(port, int) or port is None

        self._render = render
        # Extra command-line switches; "-displayMode" is "1" when fullscreen, "0" otherwise.
        self._arguments: Dict[str, str] = {"-displayMode": str(int(fullscreen))}
        if not fullscreen:
            if resolution and len(resolution) == 2:
                self._arguments["-windowwidth"] = str(resolution[0])
                self._arguments["-windowheight"] = str(resolution[1])
            if placement and len(placement) == 2:
                self._arguments["-windowx"] = str(placement[0])
                self._arguments["-windowy"] = str(placement[1])

        self._host = host or os.environ.get("SC2CLIENTHOST", "127.0.0.1")
        # The address SC2 listens on may differ from the one the client connects to.
        self._serverhost = os.environ.get("SC2SERVERHOST", self._host)

        if port is None:
            self._port = portpicker.pick_unused_port()
        else:
            self._port = port
        # Only a port we picked ourselves is handed back to portpicker in _clean().
        self._used_portpicker = port is None
        self._tmp_dir = tempfile.mkdtemp(prefix="SC2_")
        # subprocess.Popen from _launch(), or the object wsl.run() returns on WSL.
        self._process: Any = None
        self._session = None
        self._ws = None
        self._sc2_version = sc2_version
        self._base_build = base_build
        self._data_hash = data_hash

    async def __aenter__(self) -> Controller:
        """Launch SC2, connect to its websocket API and return a Controller bound to it."""
        kill_switch.add(self)

        def signal_handler(*_args):
            # unused arguments: signal handling library expects all signal
            # callback handlers to accept two positional arguments
            kill_switch.kill_all()

        signal.signal(signal.SIGINT, signal_handler)

        try:
            self._process = self._launch()
            self._ws = await self._connect()
        except BaseException:
            # Deliberately broad (also KeyboardInterrupt / CancelledError / SystemExit) so a
            # half-started process is never left behind; the exception is re-raised unchanged.
            await self._close_connection()
            self._clean()
            raise

        return Controller(self._ws, self)

    async def __aexit__(self, *args):
        """Close the connection, clean up all registered processes and restore SIGINT handling."""
        if args and args[0] is not None:
            # Leaving the context because of an exception: log it with its traceback.
            logger.exception("async exit")
        else:
            logger.debug("async exit")
        await self._close_connection()
        kill_switch.kill_all()
        signal.signal(signal.SIGINT, signal.SIG_DFL)

    @property
    def ws_url(self):
        """Websocket URL of the SC2 API for this process."""
        return f"ws://{self._host}:{self._port}/sc2api"

    @property
    def versions(self):
        """Return the known SC2 versions (VERSIONS from the local versions module), which originate from
        https://github.com/Blizzard/s2client-proto/blob/master/buildinfo/versions.json"""
        return VERSIONS

    def find_data_hash(self, target_sc2_version: str) -> Optional[str]:
        """ Returns the data hash from the matching version string, or None if no label matches. """
        for version in self.versions:
            if version["label"] == target_sc2_version:
                return version["data-hash"]
        return None

    def _launch(self):
        """
        Build the SC2 command line and start the game process.

        :return: the ``subprocess.Popen`` for the game, or the result of ``wsl.run`` on WSL
        """
        if self._base_build:
            executable = str(paths.latest_executeble(Paths.BASE / "Versions", self._base_build))
        else:
            executable = str(Paths.EXECUTABLE)
        # _clean() releases the port, so a re-used instance needs a fresh one.
        if self._port is None:
            self._port = portpicker.pick_unused_port()
            self._used_portpicker = True
        args = paths.get_runner_args(Paths.CWD) + [
            executable,
            "-listen",
            self._serverhost,
            "-port",
            str(self._port),
            "-dataDir",
            str(Paths.BASE),
            "-tempDir",
            self._tmp_dir,
        ]
        for arg, value in self._arguments.items():
            args.append(arg)
            args.append(value)
        if self._sc2_version:
            # Tests if the specified version is in the versions.py dict.
            valid_version_string = any(version["label"] == self._sc2_version for version in self.versions)
            if valid_version_string:
                self._data_hash = self.find_data_hash(self._sc2_version)
                assert (
                    self._data_hash is not None
                ), f"StarCraft 2 Client version ({self._sc2_version}) was not found inside sc2/versions.py file. Please check your spelling or check the versions.py file."
            else:
                logger.warning(
                    f'The submitted version string in sc2.rungame() function call (sc2_version="{self._sc2_version}") was not found in versions.py. Running latest version instead.'
                )

        if self._data_hash:
            args.extend(["-dataVersion", self._data_hash])

        if self._render:
            args.extend(["-eglpath", "libEGL.so"])

        args.append("-verbose")

        sc2_cwd = str(Paths.CWD) if Paths.CWD else None

        if paths.PF in {"WSL1", "WSL2"}:
            return wsl.run(args, sc2_cwd)

        return subprocess.Popen(
            args,
            cwd=sc2_cwd,
            # Suppress Wine error messages
            stderr=subprocess.DEVNULL,
        )

    async def _connect(self):
        """
        Retry connecting to SC2's websocket endpoint, sleeping 1 s before each of up to 180 attempts.

        :return: the connected websocket
        :raises TimeoutError: if no attempt succeeded
        """
        # How long it waits for SC2 to start (in seconds)
        for i in range(180):
            if self._process is None:
                # The ._clean() was called, clearing the process
                logger.debug("Process cleanup complete, exit")
                sys.exit()

            await asyncio.sleep(1)
            try:
                self._session = aiohttp.ClientSession()
                ws = await self._session.ws_connect(self.ws_url, timeout=120)
                # FIXME fix deprecation warning in for future aiohttp version
                # ws = await self._session.ws_connect(
                #     self.ws_url, timeout=aiohttp.client_ws.ClientWSTimeout(ws_close=120)
                # )
                logger.debug("Websocket connection ready")
                return ws
            except aiohttp.client_exceptions.ClientConnectorError:
                # SC2 is not listening yet: drop this session and retry with a new one.
                await self._session.close()
                if i > 15:
                    logger.debug("Connection refused (startup not complete (yet))")

        logger.debug("Websocket connection to SC2 process timed out")
        raise TimeoutError("Websocket")

    async def _close_connection(self):
        """Close the websocket and the HTTP session, if they were opened."""
        logger.info(f"Closing connection at {self._port}...")

        if self._ws is not None:
            await self._ws.close()

        if self._session is not None:
            await self._session.close()

    # pylint: disable=R0912
    def _clean(self, verbose=True):
        """
        Stop the SC2 process, delete its temp directory and release a portpicker port.

        Calling it again is harmless: the process reference, the temp dir and the port are
        each checked before use and cleared afterwards. Outside WSL a running process gets
        up to three ``terminate()`` attempts 0.5 s apart and is then force-killed.

        :param verbose: log "Cleaning up..." / "Cleanup complete" at info level
        """
        if verbose:
            logger.info("Cleaning up...")

        # Work on a local reference: a SIGINT during time.sleep() below may re-enter _clean()
        # through kill_switch and reset self._process to None.
        process = self._process
        if process is not None:
            if paths.PF in {"WSL1", "WSL2"}:
                if wsl.kill(process):
                    logger.error("KILLED")
            elif process.poll() is None:
                for _ in range(3):
                    process.terminate()
                    time.sleep(0.5)
                    if process.poll() is not None:
                        break
                else:
                    # Still running after repeated terminate(): force it down.
                    process.kill()
                    process.wait()
                    logger.error("KILLED")
            # Try to kill wineserver on linux
            if paths.PF in {"Linux", "WineLinux"}:
                # Command wineserver not detected
                with suppress(FileNotFoundError):
                    with subprocess.Popen(["wineserver", "-k"]) as p:
                        p.wait()

        if os.path.exists(self._tmp_dir):
            shutil.rmtree(self._tmp_dir)

        self._process = None
        self._ws = None
        if self._used_portpicker and self._port is not None:
            portpicker.return_port(self._port)
            self._port = None
        if verbose:
            logger.info("Cleanup complete")