245 lines
		
	
	
		
			10 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
		
		
			
		
	
	
			245 lines
		
	
	
		
			10 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| 
								 | 
							
								import itertools
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								from utils.utils import range_union
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								# adapted from ips-util for python 3.2 (https://pypi.org/project/ips-util/)
							 | 
						||
| 
								 | 
							
								class IPS_Patch(object):
								    """In-memory IPS patch: an ordered list of records plus an optional truncate length.

								    Each record is a dict with keys 'address', 'data' and 'size'; RLE records
								    additionally carry 'rle_count'. encode() serializes a record as a 3-byte
								    big-endian address, a 2-byte big-endian length (0 for RLE, followed by a
								    2-byte count) and the data bytes.
								    """

								    # A record address that encodes to b'EOF' would be read back as the
								    # end-of-patch marker, so no record may start there.
								    _EOF_ADDRESS = int.from_bytes(b'EOF', byteorder='big')

								    def __init__(self, patchDict=None):
								        """Create a patch, optionally pre-filled from {address: iterable of byte values}.

								        Each entry is added as a literal record through add_record(), so the
								        usual address/length validation applies.
								        """
								        self.records = []
								        # Output length apply() cuts to; None means no truncation.
								        self.truncate_length = None
								        # Largest (address + size) over all records added via appendRecord().
								        self.max_size = 0
								        if patchDict is not None:
								            for addr, data in patchDict.items():
								                self.add_record(addr, bytearray(data))

								    def toDict(self):
								        """Return {address: list of byte values}, with RLE records expanded.

								        If several records start at the same address, the last one wins.
								        """
								        ret = {}
								        for record in self.records:
								            if 'rle_count' in record:
								                ret[record['address']] = [int.from_bytes(record['data'], 'little')] * record['rle_count']
								            else:
								                ret[record['address']] = [int(b) for b in record['data']]
								        return ret

								    @staticmethod
								    def _read_exact(file, size):
								        """Read exactly `size` bytes from `file`; raise ValueError on a short read."""
								        chunk = file.read(size)
								        if len(chunk) != size:
								            raise ValueError('Truncated IPS patch file: expected {0} bytes, got {1}.'.format(size, len(chunk)))
								        return chunk

								    @staticmethod
								    def load(filename):
								        """Parse the IPS file at `filename` into a new IPS_Patch.

								        Raises:
								            Exception: if the file does not start with b'PATCH'.
								            ValueError: if the file ends before the b'EOF' marker or in the
								                middle of a record.
								        """
								        loaded_patch = IPS_Patch()
								        read_exact = IPS_Patch._read_exact
								        with open(filename, 'rb') as file:
								            header = file.read(5)
								            if header != b'PATCH':
								                raise Exception('Not a valid IPS patch file!')
								            while True:
								                # A short read here must not be treated as address 0 / length 0,
								                # otherwise an unterminated file would loop forever.
								                address_bytes = read_exact(file, 3)
								                if address_bytes == b'EOF':
								                    break
								                address = int.from_bytes(address_bytes, byteorder='big')
								                length = int.from_bytes(read_exact(file, 2), byteorder='big')
								                rle_count = 0
								                if length == 0:
								                    # Length 0 marks an RLE record: 2-byte count, then one value byte.
								                    rle_count = int.from_bytes(read_exact(file, 2), byteorder='big')
								                    length = 1
								                data = read_exact(file, length)
								                if rle_count > 0:
								                    loaded_patch.add_rle_record(address, data, rle_count)
								                else:
								                    loaded_patch.add_record(address, data)
								            # Optional 3-byte truncate length after the EOF marker.
								            truncate_bytes = file.read(3)
								            if len(truncate_bytes) == 3:
								                loaded_patch.set_truncate_length(int.from_bytes(truncate_bytes, byteorder='big'))
								        return loaded_patch

								    @staticmethod
								    def create(original_data, patched_data):
								        """Build a patch that turns `original_data` into `patched_data`.

								        Both arguments are sequences of byte values (bytes, bytearray, ...)
								        and neither is modified. If `patched_data` is shorter, a truncate
								        length is recorded. The heuristics for choosing literal vs. RLE
								        records were chosen with reference to the source code of Flips:
								        https://github.com/Alcaro/Flips

								        NOTE(review): only a run that *starts* at the 'EOF' address is shifted
								        back; a record split in the middle of a run can still land on that
								        address, in which case add_record()/add_rle_record() raise RuntimeError.
								        """
								        patch = IPS_Patch()

								        if len(original_data) > len(patched_data):
								            patch.set_truncate_length(len(patched_data))
								            original_data = original_data[:len(patched_data)]
								        elif len(original_data) < len(patched_data):
								            # Pad into a NEW object: `+=` would grow a caller's bytearray in place.
								            original_data = bytes(original_data) + bytes(len(patched_data) - len(original_data))
								            # The padding always ends in 0, so a trailing 0 in patched_data is not
								            # part of any run; add a one-byte record so apply() still grows the
								            # output to the full patched length.
								            if patched_data[-1] == 0:
								                patch.add_record(len(patched_data) - 1, bytes([0]))

								        # Collect maximal runs of differing bytes as (start, patched bytes).
								        runs = []
								        run_start = None
								        run_data = bytearray()
								        for index, (original, patched) in enumerate(zip(original_data, patched_data)):
								            if original != patched:
								                if run_start is None:
								                    run_start = index
								                    run_data = bytearray()
								                run_data.append(patched)
								            elif run_start is not None:
								                runs.append((run_start, run_data))
								                run_start = None
								        if run_start is not None:
								            runs.append((run_start, run_data))

								        def add_rle_split(address, value, count):
								            # Emit RLE records of at most 0xffff bytes each; return the next address.
								            while count > 0:
								                chunk = min(count, 0xffff)
								                patch.add_rle_record(address, bytes([value]), chunk)
								                address += chunk
								                count -= chunk
								            return address

								        for start, data in runs:
								            if start == IPS_Patch._EOF_ADDRESS:
								                # Shift the run back one byte. The byte at the new start is
								                # unchanged (runs are separated by equal bytes), so rewriting it
								                # with its patched value is harmless.
								                start -= 1
								                data = bytes([patched_data[start]]) + data

								            grouped_byte_data = [
								                {'val': key, 'count': sum(1 for _ in group), 'is_last': False}
								                for key, group in itertools.groupby(data)
								            ]
								            grouped_byte_data[-1]['is_last'] = True

								            record_in_progress = bytearray()
								            pos = start

								            for group in grouped_byte_data:
								                value, count = group['val'], group['count']
								                if record_in_progress:
								                    # Only interrupt a literal record when the repeated group is longer
								                    # than what the interruption costs: an 8-byte RLE record plus a
								                    # 5-byte header to resume literal data (13 bytes).
								                    if count > 13:
								                        patch.add_record(pos, record_in_progress)
								                        pos += len(record_in_progress)
								                        record_in_progress = bytearray()
								                        pos = add_rle_split(pos, value, count)
								                    else:
								                        record_in_progress += bytes([value]) * count
								                elif (count > 3 and group['is_last']) or count > 8:
								                    # An RLE record pays off for groups longer than 8 bytes, or longer
								                    # than 3 bytes when this is the last part of the run.
								                    pos = add_rle_split(pos, value, count)
								                else:
								                    # Just begin a new literal record.
								                    record_in_progress += bytes([value]) * count

								                # Literal records are limited to 0xffff bytes; flush full chunks.
								                while len(record_in_progress) > 0xffff:
								                    patch.add_record(pos, record_in_progress[:0xffff])
								                    record_in_progress = record_in_progress[0xffff:]
								                    pos += 0xffff

								            # Finalize any record still in progress.
								            if record_in_progress:
								                patch.add_record(pos, record_in_progress)

								        return patch

								    def add_record(self, address, data):
								        """Append a literal record writing `data` at `address`.

								        Empty `data` is ignored (after validation).

								        Raises:
								            RuntimeError: if `address` is the 'EOF' address or exceeds 3 bytes,
								                or if `data` is longer than 0xffff bytes.
								        """
								        if address == self._EOF_ADDRESS:
								            raise RuntimeError('Start address {0:x} is invalid in the IPS format. Please shift your starting address back by one byte to avoid it.'.format(address))
								        if address > 0xffffff:
								            raise RuntimeError('Start address {0:x} is too large for the IPS format. Addresses must fit into 3 bytes.'.format(address))
								        if len(data) > 0xffff:
								            raise RuntimeError('Record with length {0} is too large for the IPS format. Records must be less than 65536 bytes.'.format(len(data)))
								        if len(data) == 0:  # ignore empty records
								            return
								        record = {'address': address, 'data': data, 'size': len(data)}
								        self.appendRecord(record)

								    def add_rle_record(self, address, data, count):
								        """Append an RLE record writing the single byte `data` `count` times at `address`.

								        Raises:
								            RuntimeError: if `address` is the 'EOF' address or exceeds 3 bytes,
								                if `count` exceeds 0xffff, or if `data` is not exactly one byte.
								        """
								        if address == self._EOF_ADDRESS:
								            raise RuntimeError('Start address {0:x} is invalid in the IPS format. Please shift your starting address back by one byte to avoid it.'.format(address))
								        if address > 0xffffff:
								            raise RuntimeError('Start address {0:x} is too large for the IPS format. Addresses must fit into 3 bytes.'.format(address))
								        if count > 0xffff:
								            raise RuntimeError('RLE record with length {0} is too large for the IPS format. RLE records must be less than 65536 bytes.'.format(count))
								        if len(data) != 1:
								            raise RuntimeError('Data for RLE record must be exactly one byte! Received {0}.'.format(data))
								        record = {'address': address, 'data': data, 'rle_count': count, 'size': count}
								        self.appendRecord(record)

								    def appendRecord(self, record):
								        """Append an already-built record dict and update max_size (no validation)."""
								        end = record['address'] + record['size']
								        if end > self.max_size:
								            self.max_size = end
								        self.records.append(record)

								    def set_truncate_length(self, truncate_length):
								        """Set the length apply() truncates its output to (None disables it)."""
								        self.truncate_length = truncate_length

								    def encode(self):
								        """Serialize the patch to IPS bytes and return them as a bytearray."""
								        encoded_bytes = bytearray(b'PATCH')
								        for record in self.records:
								            encoded_bytes += record['address'].to_bytes(3, byteorder='big')
								            if 'rle_count' in record:
								                # RLE: zero length, then 2-byte count; the single value byte follows.
								                encoded_bytes += (0).to_bytes(2, byteorder='big')
								                encoded_bytes += record['rle_count'].to_bytes(2, byteorder='big')
								            else:
								                encoded_bytes += len(record['data']).to_bytes(2, byteorder='big')
								            encoded_bytes += record['data']
								        encoded_bytes += b'EOF'
								        if self.truncate_length is not None:
								            encoded_bytes += self.truncate_length.to_bytes(3, byteorder='big')
								        return encoded_bytes

								    # save patch into IPS file
								    def save(self, path):
								        """Write encode() output to the file at `path` (overwriting it)."""
								        with open(path, 'wb') as ipsFile:
								            ipsFile.write(self.encode())

								    # applies patch on an existing bytearray
								    def apply(self, in_data):
								        """Return a new bytearray with every record applied to a copy of `in_data`.

								        Records past the end of the data grow it (gaps are zero-filled);
								        the result is then cut to truncate_length when one is set.
								        """
								        out_data = bytearray(in_data)
								        for record in self.records:
								            address = record['address']
								            if address >= len(out_data):
								                # Zero-fill up to and including `address`; the slice assignment
								                # below extends the data for the rest of the record.
								                out_data += bytes(address - len(out_data) + 1)
								            if 'rle_count' in record:
								                out_data[address:address + record['rle_count']] = b''.join([record['data']] * record['rle_count'])
								            else:
								                out_data[address:address + len(record['data'])] = record['data']
								        if self.truncate_length is not None:
								            out_data = out_data[:self.truncate_length]
								        return out_data

								    # applies patch on an opened file
								    def applyFile(self, handle):
								        """Write every record into the seekable binary file `handle`.

								        NOTE: truncate_length is not applied here.
								        """
								        for record in self.records:
								            handle.seek(record['address'])
								            if 'rle_count' in record:
								                handle.write(bytes(record['data']) * record['rle_count'])
								            else:
								                handle.write(record['data'])

								    # appends an IPS_Patch on top of this one
								    def append(self, patch):
								        """Append `patch`'s non-empty records after this patch's records.

								        The truncate length becomes `patch`'s when it is set and either this
								        patch has none or `patch`'s is larger.
								        """
								        if patch.truncate_length is not None and (self.truncate_length is None or patch.truncate_length > self.truncate_length):
								            self.set_truncate_length(patch.truncate_length)
								        for record in patch.records:
								            if record['size'] > 0:  # ignore empty records
								                self.appendRecord(record)

								    # gets address ranges written to by this patch
								    def getRanges(self):
								        """Return range_union() of the [address, address + size) range of each record."""
								        def getRange(record):
								            return range(record['address'], record['address'] + record['size'])
								        return range_union([getRange(record) for record in self.records])
							 |