| 
									
										
										
										
											2020-04-22 05:09:46 +02:00
										 |  |  | from __future__ import annotations | 
					
						
							| 
									
										
										
										
											2020-06-21 15:32:31 +02:00
										 |  |  | import typing | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def tuplize_version(version: str) -> typing.Tuple[int, ...]: | 
					
						
							|  |  |  |     return tuple(int(piece, 10) for piece in version.split(".")) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-04-22 05:09:46 +02:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-10-30 00:10:41 +01:00
# Version of this code base as a dotted string.
__version__ = "3.3.0"
# The same version parsed into a tuple of ints by tuplize_version().
_version_tuple = tuplize_version(__version__)
					
						
							| 
									
										
										
										
											2020-04-20 14:50:49 +02:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2017-11-28 09:36:32 -05:00
										 |  |  | import os | 
					
						
							| 
									
										
										
										
											2017-12-17 00:25:46 -05:00
										 |  |  | import subprocess | 
					
						
							| 
									
										
										
										
											2017-11-28 09:36:32 -05:00
										 |  |  | import sys | 
					
						
							| 
									
										
										
										
											2020-09-09 01:41:37 +02:00
										 |  |  | import pickle | 
					
						
							|  |  |  | import io | 
					
						
							|  |  |  | import builtins | 
					
						
							| 
									
										
										
										
											2020-06-21 15:32:31 +02:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-02-16 15:32:40 +01:00
										 |  |  | import functools | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-07-05 02:06:00 +02:00
										 |  |  | from yaml import load, dump, safe_load | 
					
						
							| 
									
										
										
										
											2020-02-16 15:32:40 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  | try: | 
					
						
							|  |  |  |     from yaml import CLoader as Loader | 
					
						
							|  |  |  | except ImportError: | 
					
						
							|  |  |  |     from yaml import Loader | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2017-11-28 09:36:32 -05:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2018-02-17 18:38:54 -05:00
										 |  |  | def int16_as_bytes(value): | 
					
						
							|  |  |  |     value = value & 0xFFFF | 
					
						
							|  |  |  |     return [value & 0xFF, (value >> 8) & 0xFF] | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-02-16 15:32:40 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2018-02-17 18:38:54 -05:00
										 |  |  | def int32_as_bytes(value): | 
					
						
							|  |  |  |     value = value & 0xFFFFFFFF | 
					
						
							|  |  |  |     return [value & 0xFF, (value >> 8) & 0xFF, (value >> 16) & 0xFF, (value >> 24) & 0xFF] | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-02-16 15:32:40 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2018-09-22 22:51:54 -04:00
										 |  |  | def pc_to_snes(value): | 
					
						
							|  |  |  |     return ((value<<1) & 0x7F0000)|(value & 0x7FFF)|0x8000 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-07-21 23:15:19 +02:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2018-09-22 22:51:54 -04:00
										 |  |  | def snes_to_pc(value): | 
					
						
							|  |  |  |     return ((value & 0x7F0000)>>1)|(value & 0x7FFF) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-07-21 23:15:19 +02:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-01-14 10:42:27 +01:00
										 |  |  | def parse_player_names(names, players, teams): | 
					
						
							| 
									
										
										
										
											2020-03-02 23:27:16 +01:00
										 |  |  |     names = tuple(n for n in (n.strip() for n in names.split(",")) if n) | 
					
						
							| 
									
										
										
										
											2020-08-20 04:03:49 +02:00
										 |  |  |     if len(names) != len(set(names)): | 
					
						
							|  |  |  |         raise ValueError("Duplicate Player names is not supported.") | 
					
						
							| 
									
										
										
										
											2020-01-14 10:42:27 +01:00
										 |  |  |     ret = [] | 
					
						
							|  |  |  |     while names or len(ret) < teams: | 
					
						
							|  |  |  |         team = [n[:16] for n in names[:players]] | 
					
						
							| 
									
										
										
										
											2020-08-20 04:03:49 +02:00
										 |  |  |         # 16 bytes in rom per player, which will map to more in unicode, but those characters later get filtered | 
					
						
							| 
									
										
										
										
											2020-01-14 10:42:27 +01:00
										 |  |  |         while len(team) != players: | 
					
						
							| 
									
										
										
										
											2020-04-18 21:46:57 +02:00
										 |  |  |             team.append(f"Player{len(team) + 1}") | 
					
						
							| 
									
										
										
										
											2020-01-14 10:42:27 +01:00
										 |  |  |         ret.append(team) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         names = names[players:] | 
					
						
							|  |  |  |     return ret | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-07-21 23:15:19 +02:00
										 |  |  | 
 | 
					
						
							|  |  |  | def is_bundled() -> bool: | 
					
						
							| 
									
										
										
										
											2017-11-28 09:36:32 -05:00
										 |  |  |     return getattr(sys, 'frozen', False) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-07-21 23:15:19 +02:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-08-25 13:22:47 +02:00
										 |  |  | def local_path(*path): | 
					
						
							| 
									
										
										
										
											2020-03-15 19:32:00 +01:00
										 |  |  |     if local_path.cached_path: | 
					
						
							| 
									
										
										
										
											2020-08-25 13:22:47 +02:00
										 |  |  |         return os.path.join(local_path.cached_path, *path) | 
					
						
							| 
									
										
										
										
											2017-11-28 09:36:32 -05:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-23 07:45:40 +01:00
										 |  |  |     elif is_bundled(): | 
					
						
							|  |  |  |         if hasattr(sys, "_MEIPASS"): | 
					
						
							|  |  |  |             # we are running in a PyInstaller bundle | 
					
						
							|  |  |  |             local_path.cached_path = sys._MEIPASS  # pylint: disable=protected-access,no-member | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             # cx_Freeze | 
					
						
							|  |  |  |             local_path.cached_path = os.path.dirname(os.path.abspath(sys.argv[0])) | 
					
						
							| 
									
										
										
										
											2017-11-28 09:36:32 -05:00
										 |  |  |     else: | 
					
						
							|  |  |  |         # we are running in a normal Python environment | 
					
						
							| 
									
										
										
										
											2020-03-23 07:45:40 +01:00
										 |  |  |         import __main__ | 
					
						
							|  |  |  |         local_path.cached_path = os.path.dirname(os.path.abspath(__main__.__file__)) | 
					
						
							| 
									
										
										
										
											2020-03-15 19:32:00 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-08-25 13:22:47 +02:00
										 |  |  |     return os.path.join(local_path.cached_path, *path) | 
					
						
							| 
									
										
										
										
											2017-11-28 09:36:32 -05:00
										 |  |  | 
 | 
					
						
							|  |  |  | local_path.cached_path = None | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-08-25 13:22:47 +02:00
										 |  |  | 
 | 
					
						
							|  |  |  | def output_path(*path): | 
					
						
							| 
									
										
										
										
											2020-03-15 19:32:00 +01:00
										 |  |  |     if output_path.cached_path: | 
					
						
							| 
									
										
										
										
											2020-08-25 13:22:47 +02:00
										 |  |  |         return os.path.join(output_path.cached_path, *path) | 
					
						
							| 
									
										
										
										
											2020-08-20 15:43:22 +02:00
										 |  |  |     output_path.cached_path = local_path(get_options()["general_options"]["output_path"]) | 
					
						
							| 
									
										
										
										
											2020-08-25 13:22:47 +02:00
										 |  |  |     path = os.path.join(output_path.cached_path, *path) | 
					
						
							| 
									
										
										
										
											2020-08-01 16:52:11 +02:00
										 |  |  |     os.makedirs(os.path.dirname(path), exist_ok=True) | 
					
						
							|  |  |  |     return path | 
					
						
							| 
									
										
										
										
											2017-11-28 09:36:32 -05:00
										 |  |  | 
 | 
					
						
							|  |  |  | output_path.cached_path = None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def open_file(filename): | 
					
						
							|  |  |  |     if sys.platform == 'win32': | 
					
						
							|  |  |  |         os.startfile(filename) | 
					
						
							|  |  |  |     else: | 
					
						
							| 
									
										
										
										
											2017-12-17 00:25:46 -05:00
										 |  |  |         open_command = 'open' if sys.platform == 'darwin' else 'xdg-open' | 
					
						
							| 
									
										
										
										
											2017-11-28 09:36:32 -05:00
										 |  |  |         subprocess.call([open_command, filename]) | 
					
						
							| 
									
										
										
										
											2017-12-02 09:21:04 -05:00
										 |  |  | 
 | 
					
						
							|  |  |  | def close_console(): | 
					
						
							|  |  |  |     if sys.platform == 'win32': | 
					
						
							|  |  |  |         #windows | 
					
						
							|  |  |  |         import ctypes.wintypes | 
					
						
							|  |  |  |         try: | 
					
						
							|  |  |  |             ctypes.windll.kernel32.FreeConsole() | 
					
						
							| 
									
										
										
										
											2017-12-17 00:25:46 -05:00
										 |  |  |         except Exception: | 
					
						
							| 
									
										
										
										
											2017-12-02 09:21:04 -05:00
										 |  |  |             pass | 
					
						
							| 
									
										
										
										
											2018-01-01 14:42:23 -05:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-02-09 05:28:48 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-07-05 02:06:00 +02:00
# parse_yaml is an alias for yaml.safe_load.
parse_yaml = safe_load
# unsafe_parse_yaml calls yaml.load with the Loader imported above (CLoader when
# available, otherwise the pure-Python Loader).
# NOTE(review): as the name says, this loader is presumably able to construct
# arbitrary Python objects - only feed it trusted input; confirm against the
# PyYAML documentation.
unsafe_parse_yaml = functools.partial(load, Loader=Loader)
					
						
							| 
									
										
										
										
											2020-02-16 15:32:40 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-02-09 05:28:48 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-02-16 15:32:40 +01:00
										 |  |  | class Hint(typing.NamedTuple): | 
					
						
							|  |  |  |     receiving_player: int | 
					
						
							|  |  |  |     finding_player: int | 
					
						
							|  |  |  |     location: int | 
					
						
							|  |  |  |     item: int | 
					
						
							|  |  |  |     found: bool | 
					
						
							| 
									
										
										
										
											2020-05-18 05:40:36 +02:00
										 |  |  |     entrance: str = "" | 
					
						
							| 
									
										
										
										
											2020-03-06 00:48:23 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-04-22 05:09:46 +02:00
										 |  |  |     def re_check(self, ctx, team) -> Hint: | 
					
						
							|  |  |  |         if self.found: | 
					
						
							|  |  |  |             return self | 
					
						
							|  |  |  |         found = self.location in ctx.location_checks[team, self.finding_player] | 
					
						
							|  |  |  |         if found: | 
					
						
							| 
									
										
										
										
											2020-05-18 05:40:36 +02:00
										 |  |  |             return Hint(self.receiving_player, self.finding_player, self.location, self.item, found, self.entrance) | 
					
						
							| 
									
										
										
										
											2020-04-22 05:09:46 +02:00
										 |  |  |         return self | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-05-18 05:40:36 +02:00
										 |  |  |     def as_legacy(self) -> tuple: | 
					
						
							|  |  |  |         return self.receiving_player, self.finding_player, self.location, self.item, self.found | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-04-22 15:50:14 +02:00
										 |  |  |     def __hash__(self): | 
					
						
							| 
									
										
										
										
											2020-05-18 05:40:36 +02:00
										 |  |  |         return hash((self.receiving_player, self.finding_player, self.location, self.item, self.entrance)) | 
					
						
							| 
									
										
										
										
											2020-04-22 15:50:14 +02:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-06 00:48:23 +01:00
										 |  |  | def get_public_ipv4() -> str: | 
					
						
							|  |  |  |     import socket | 
					
						
							|  |  |  |     import urllib.request | 
					
						
							|  |  |  |     import logging | 
					
						
							|  |  |  |     ip = socket.gethostbyname(socket.gethostname()) | 
					
						
							|  |  |  |     try: | 
					
						
							|  |  |  |         ip = urllib.request.urlopen('https://checkip.amazonaws.com/').read().decode('utf8').strip() | 
					
						
							|  |  |  |     except Exception as e: | 
					
						
							|  |  |  |         try: | 
					
						
							|  |  |  |             ip = urllib.request.urlopen('https://v4.ident.me').read().decode('utf8').strip() | 
					
						
							|  |  |  |         except: | 
					
						
							|  |  |  |             logging.exception(e) | 
					
						
							|  |  |  |             pass  # we could be offline, in a local game, so no point in erroring out | 
					
						
							|  |  |  |     return ip | 
					
						
							| 
									
										
										
										
											2020-03-15 19:32:00 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-06-14 09:06:37 +02:00
										 |  |  | def get_public_ipv6() -> str: | 
					
						
							|  |  |  |     import socket | 
					
						
							|  |  |  |     import urllib.request | 
					
						
							|  |  |  |     import logging | 
					
						
							|  |  |  |     ip = socket.gethostbyname(socket.gethostname()) | 
					
						
							|  |  |  |     try: | 
					
						
							|  |  |  |         ip = urllib.request.urlopen('https://v6.ident.me').read().decode('utf8').strip() | 
					
						
							|  |  |  |     except Exception as e: | 
					
						
							|  |  |  |         logging.exception(e) | 
					
						
							| 
									
										
										
										
											2020-06-21 16:13:42 +02:00
										 |  |  |         pass  # we could be offline, in a local game, or ipv6 may not be available | 
					
						
							| 
									
										
										
										
											2020-06-14 09:06:37 +02:00
										 |  |  |     return ip | 
					
						
							| 
									
										
										
										
											2020-03-15 19:32:00 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  | def get_options() -> dict: | 
					
						
							| 
									
										
										
										
											2020-03-23 07:59:55 +01:00
										 |  |  |     if not hasattr(get_options, "options"): | 
					
						
							| 
									
										
										
										
											2020-03-15 19:32:00 +01:00
										 |  |  |         locations = ("options.yaml", "host.yaml", | 
					
						
							|  |  |  |                      local_path("options.yaml"), local_path("host.yaml")) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         for location in locations: | 
					
						
							|  |  |  |             if os.path.exists(location): | 
					
						
							|  |  |  |                 with open(location) as f: | 
					
						
							| 
									
										
										
										
											2020-03-23 07:59:55 +01:00
										 |  |  |                     get_options.options = parse_yaml(f.read()) | 
					
						
							|  |  |  |                 break | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             raise FileNotFoundError(f"Could not find {locations[1]} to load options.") | 
					
						
							|  |  |  |     return get_options.options | 
					
						
							| 
									
										
										
										
											2020-04-14 20:22:42 +02:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def get_item_name_from_id(code): | 
					
						
							|  |  |  |     import Items | 
					
						
							|  |  |  |     return Items.lookup_id_to_name.get(code, f'Unknown item (ID:{code})') | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def get_location_name_from_address(address): | 
					
						
							|  |  |  |     import Regions | 
					
						
							|  |  |  |     return Regions.lookup_id_to_name.get(address, f'Unknown location (ID:{address})') | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-04-24 05:29:02 +02:00
										 |  |  | def persistent_store(category, key, value): | 
					
						
							|  |  |  |     path = local_path("_persistent_storage.yaml") | 
					
						
							|  |  |  |     storage: dict = persistent_load() | 
					
						
							|  |  |  |     category = storage.setdefault(category, {}) | 
					
						
							|  |  |  |     category[key] = value | 
					
						
							|  |  |  |     with open(path, "wt") as f: | 
					
						
							|  |  |  |         f.write(dump(storage)) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-04-26 15:14:30 +02:00
										 |  |  | def persistent_load() -> typing.Dict[dict]: | 
					
						
							| 
									
										
										
										
											2020-06-04 21:27:29 +02:00
										 |  |  |     storage = getattr(persistent_load, "storage", None) | 
					
						
							|  |  |  |     if storage: | 
					
						
							|  |  |  |         return storage | 
					
						
							| 
									
										
										
										
											2020-04-24 05:29:02 +02:00
										 |  |  |     path = local_path("_persistent_storage.yaml") | 
					
						
							|  |  |  |     storage: dict = {} | 
					
						
							|  |  |  |     if os.path.exists(path): | 
					
						
							|  |  |  |         try: | 
					
						
							|  |  |  |             with open(path, "r") as f: | 
					
						
							| 
									
										
										
										
											2020-07-05 02:06:00 +02:00
										 |  |  |                 storage = unsafe_parse_yaml(f.read()) | 
					
						
							| 
									
										
										
										
											2020-04-24 05:29:02 +02:00
										 |  |  |         except Exception as e: | 
					
						
							|  |  |  |             import logging | 
					
						
							|  |  |  |             logging.debug(f"Could not read store: {e}") | 
					
						
							| 
									
										
										
										
											2020-04-29 22:42:26 -07:00
										 |  |  |     if storage is None: | 
					
						
							|  |  |  |         storage = {} | 
					
						
							| 
									
										
										
										
											2020-06-04 21:27:29 +02:00
										 |  |  |     persistent_load.storage = storage | 
					
						
							| 
									
										
										
										
											2020-04-24 05:29:02 +02:00
										 |  |  |     return storage | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-06-09 18:02:15 +02:00
										 |  |  | def get_adjuster_settings(romfile: str) -> typing.Tuple[str, bool]: | 
					
						
							| 
									
										
										
										
											2020-06-07 12:04:33 -07:00
										 |  |  |     if hasattr(get_adjuster_settings, "adjuster_settings"): | 
					
						
							|  |  |  |         adjuster_settings = getattr(get_adjuster_settings, "adjuster_settings") | 
					
						
							|  |  |  |     else: | 
					
						
							|  |  |  |         adjuster_settings = persistent_load().get("adjuster", {}).get("last_settings", {}) | 
					
						
							|  |  |  |     if adjuster_settings: | 
					
						
							|  |  |  |         import pprint | 
					
						
							|  |  |  |         import Patch | 
					
						
							|  |  |  |         adjuster_settings.rom = romfile | 
					
						
							|  |  |  |         adjuster_settings.baserom = Patch.get_base_rom_path() | 
					
						
							|  |  |  |         whitelist = {"disablemusic", "fastmenu", "heartbeep", "heartcolor", "ow_palettes", "quickswap", | 
					
						
							|  |  |  |                      "uw_palettes"} | 
					
						
							|  |  |  |         printed_options = {name: value for name, value in vars(adjuster_settings).items() if name in whitelist} | 
					
						
							|  |  |  |         sprite = getattr(adjuster_settings, "sprite", None) | 
					
						
							|  |  |  |         if sprite: | 
					
						
							|  |  |  |             printed_options["sprite"] = adjuster_settings.sprite.name | 
					
						
							|  |  |  |         if hasattr(get_adjuster_settings, "adjust_wanted"): | 
					
						
							|  |  |  |             adjust_wanted = getattr(get_adjuster_settings, "adjust_wanted") | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             adjust_wanted = input(f"Last used adjuster settings were found. Would you like to apply these? \n" | 
					
						
							|  |  |  |                                   f"{pprint.pformat(printed_options)}\n" | 
					
						
							|  |  |  |                                   f"Enter yes or no: ") | 
					
						
							|  |  |  |         if adjust_wanted and adjust_wanted.startswith("y"): | 
					
						
							|  |  |  |             adjusted = True | 
					
						
							|  |  |  |             import AdjusterMain | 
					
						
							|  |  |  |             _, romfile = AdjusterMain.adjust(adjuster_settings) | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             adjusted = False | 
					
						
							|  |  |  |             import logging | 
					
						
							|  |  |  |             if not hasattr(get_adjuster_settings, "adjust_wanted"): | 
					
						
							|  |  |  |                 logging.info(f"Skipping post-patch adjustment") | 
					
						
							|  |  |  |         get_adjuster_settings.adjuster_settings = adjuster_settings | 
					
						
							|  |  |  |         get_adjuster_settings.adjust_wanted = adjust_wanted | 
					
						
							|  |  |  |         return romfile, adjusted | 
					
						
							| 
									
										
										
										
											2020-06-09 18:02:15 +02:00
										 |  |  |     return romfile, False | 
					
						
							| 
									
										
										
										
											2020-06-07 12:04:33 -07:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-04-14 20:22:42 +02:00
class ReceivedItem(typing.NamedTuple):
    """Immutable (item, location, player) record of three integers.

    NOTE(review): the fields are presumably numeric item/location/player
    ids - meanings are inferred from the names; confirm against callers.
    """
    item: int
    location: int
    player: int
					
						
							| 
									
										
										
										
											2020-06-04 21:27:29 +02:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def get_unique_identifier(): | 
					
						
							|  |  |  |     uuid = persistent_load().get("client", {}).get("uuid", None) | 
					
						
							|  |  |  |     if uuid: | 
					
						
							|  |  |  |         return uuid | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     import uuid | 
					
						
							|  |  |  |     uuid = uuid.getnode() | 
					
						
							|  |  |  |     persistent_store("client", "uuid", uuid) | 
					
						
							|  |  |  |     return uuid | 
					
						
							| 
									
										
										
										
											2020-09-09 01:41:37 +02:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | safe_builtins = { | 
					
						
							|  |  |  |     'set', | 
					
						
							|  |  |  |     'frozenset', | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class RestrictedUnpickler(pickle.Unpickler): | 
					
						
							|  |  |  |     def find_class(self, module, name): | 
					
						
							|  |  |  |         if module == "builtins" and name in safe_builtins: | 
					
						
							|  |  |  |             return getattr(builtins, name) | 
					
						
							|  |  |  |         # Forbid everything else. | 
					
						
							|  |  |  |         raise pickle.UnpicklingError("global '%s.%s' is forbidden" % | 
					
						
							|  |  |  |                                      (module, name)) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def restricted_loads(s): | 
					
						
							|  |  |  |     """Helper function analogous to pickle.loads().""" | 
					
						
							|  |  |  |     return RestrictedUnpickler(io.BytesIO(s)).load() |