_solver.py
# Copyright (C) 2025 Armin "Era" Ramezani <e@4d2.org>
#
# This file is a part of patchman.
#
# patchman is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later
# version.
#
# patchman is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied
# warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
# details.
#
# You should have received a copy of the GNU Lesser General Public License along with patchman. If not, see
# <https://www.gnu.org/licenses/>.
#
"""Diff functionality.

All diff algorithms in this module share one result format: a pair ``(deletions, additions)`` of lists whose items
are either

* an ``int`` -- the number of ``base`` items to step over, or
* a content piece (``bytes`` or ``OperationalList``) -- the items removed from ``base`` (in ``deletions``) or the
  items inserted at the current ``base`` position (in ``additions``).

``byte_diff`` turns such a pair into a list of ``Diff`` objects and ``byte_patch`` applies that list.
"""

import math

try:
    from typing import TYPE_CHECKING
except ImportError:
    TYPE_CHECKING = False

if TYPE_CHECKING:
    from typing import Any, List, Literal, Optional, Sequence, Tuple, Union


class ConflictError(Exception):
    """Raised when two diffs cannot be combined."""


class OperationalList(object):
    """List wrapper exposing the subset of the ``bytes`` API the diff algorithms rely on.

    Slicing returns an instance of the same class and ``find`` searches for a contiguous sub-sequence, so a sequence
    of chunks can be diffed by the same code that diffs raw ``bytes``.
    """

    def __init__(self, seq=()):  # type: (Sequence) -> None
        self._list = list(seq)

    def __add__(self, other):  # type: (OperationalList) -> OperationalList
        assert isinstance(other, OperationalList)
        return OperationalList(self._list + other._list)

    def __contains__(self, item):  # type: (OperationalList) -> bool
        # Sub-sequence containment (like ``bytes``), not element membership.
        return self.find(item) != -1

    def __eq__(self, other):  # type: (object) -> bool
        if isinstance(other, OperationalList):
            return self._list == other._list
        return self._list == other

    def __ne__(self, other):  # type: (object) -> bool
        # Python 2 does not derive ``!=`` from ``__eq__``.
        return not self == other

    def __getitem__(self, item):  # type: (Union[slice, int]) -> Any
        if isinstance(item, int):
            return self._list[item]
        return self.__class__(self._list[item])

    def __iter__(self):
        return iter(self._list)

    def __len__(self):  # type: () -> int
        return len(self._list)

    def __repr__(self):  # type: () -> str
        return repr(self._list)

    def __setitem__(self, key, value):  # type: (int, Any) -> None
        self._list[key] = value

    def find(self, sub, start=0):  # type: (OperationalList, int) -> int
        """Return the lowest index >= ``start`` at which ``sub`` occurs as a contiguous run, or -1.

        A negative ``start`` counts from the end and is clamped to 0, so it can never yield a negative "match".
        """
        if start < 0:
            start = max(len(self) + start, 0)
        width = len(sub)
        for index in range(start, len(self) - width + 1):
            if self[index:index + width] == sub:
                return index
        return -1


class Diff(object):
    """A single insertion or deletion relative to a base sequence.

    Attributes:
        addition: ``True`` for an insertion, ``False`` for a deletion.
        pointer: index in the base sequence where the change applies.
        content: the inserted items, or the items expected to be removed.
    """

    def __init__(self, addition, pointer, content):  # type: (bool, int, Union[bytes, OperationalList]) -> None
        self.addition = addition
        self.pointer = pointer
        self.content = content

    def __add__(self, other):  # type: (Diff) -> Diff
        """Concatenate two adjacent diffs of the same kind.

        Insertions must share a pointer; a deletion must end exactly where ``other`` starts.

        Raises:
            ConflictError: if the diffs are of different kinds or are not adjacent.
        """
        if self.addition is not other.addition:
            raise ConflictError('Diff addition failed.')
        expected_pointer = self.pointer if self.addition else self.pointer + len(self)
        if expected_pointer != other.pointer:
            raise ConflictError('Diff addition failed.')
        return self.__class__(self.addition, self.pointer, self.content + other.content)

    def __eq__(self, other):  # type: (Diff) -> bool
        if not isinstance(other, Diff):
            return NotImplemented
        return self.addition == other.addition and self.pointer == other.pointer and self.content == other.content

    def __ne__(self, other):  # type: (Diff) -> bool
        # Python 2 does not derive ``!=`` from ``__eq__``.
        result = self.__eq__(other)
        return result if result is NotImplemented else not result

    def __len__(self):  # type: () -> int
        return len(self.content)

    def __repr__(self):  # type: () -> str
        return repr([self.addition, self.pointer, self.content])

    def __sub__(self, other):  # type: (Diff) -> Diff
        """Return a copy of this diff without the tail that overlaps the head of ``other``.

        The overlap is the longest suffix of ``self.content`` equal to a prefix of ``other.content``. Insertions only
        overlap when they share a pointer; for deletions the overlapping part must also line up positionally. When
        nothing overlaps, the returned copy has the same length as ``self``.

        Raises:
            ConflictError: if the diffs are of different kinds.
        """
        if self.addition is not other.addition:
            raise ConflictError('Diff subtraction failed.')
        if self.addition and self.pointer != other.pointer:
            return self.__class__(self.addition, self.pointer, self.content)
        overlap = min(len(self), len(other))
        while overlap != 0:
            if self.content[-overlap:] == other.content[:overlap] and (
                self.addition or self.pointer + len(self) - overlap == other.pointer
            ):
                break
            overlap -= 1
        return self.__class__(self.addition, self.pointer, self.content[:len(self) - overlap])

    def __contains__(self, item):  # type: (Diff) -> bool
        """Tell whether ``item`` is fully covered by this diff.

        An insertion covers another insertion at the same pointer whose content occurs inside its own. A deletion
        covers another deletion whose content sits at the matching offset inside its own content.
        """
        if self.addition is not item.addition:
            return False
        if self.addition:
            return self.pointer == item.pointer and item.content in self.content
        offset = item.pointer - self.pointer
        if not 0 <= offset < len(self):
            return False
        end = offset + len(item)
        return end <= len(self) and self.content[offset:end] == item.content


def _content_size(pieces):  # type: (List[Union[bytes, int, OperationalList]]) -> int
    """Return the total length of the content (non-integer) pieces of a diff stream."""
    return sum(len(piece) for piece in pieces if not isinstance(piece, int))


def _udiff(
        base,  # type: Union[bytes, OperationalList]
        new,  # type: Union[bytes, OperationalList]
        size_override=None,  # type: Optional[int]
):  # type: (...) -> Tuple[List[Union[bytes, int, OperationalList]], List[Union[bytes, int, OperationalList]]]
    """Exhaustive diff: try every occurrence of the longest common run and keep the cheapest result.

    The cost of a result is the amount of deleted plus inserted content. ``size_override`` caps the run length that
    is tried first (used by the recursion). Returns ``(deletions, additions)`` in the module's stream format.
    """
    base_len = len(base)
    new_len = len(new)
    if size_override is None:
        size = min(base_len, new_len)
    else:
        size = min(base_len, new_len, size_override)
    while size > 0:
        matches = []
        for i in range(base_len - size + 1):
            # FIXME: its probably faster to use regex here
            window = base[i:i + size]
            j = 0
            while j < new_len:
                k = new.find(window, j)
                if k == -1:
                    break
                matches.append((i, k))
                j = k + 1
        best = None
        best_penalty = None
        for i, k in matches:
            ldeletions, ladditions = _udiff(base[:i], new[:k], size)
            rdeletions, radditions = _udiff(base[i + size:], new[k + size:], size)
            deletions = ldeletions + [size] + rdeletions
            additions = ladditions + [size] + radditions
            # Count every content piece, not just ``bytes``: with ``split`` the pieces are ``OperationalList``s.
            penalty = _content_size(deletions) + _content_size(additions)
            if best_penalty is None or penalty < best_penalty:
                best = (deletions, additions)
                best_penalty = penalty
        if best is not None:
            return best
        size -= 1
    return [base], [new, base_len]


def _greedy_udiff(
        base,  # type: Union[bytes, OperationalList]
        new,  # type: Union[bytes, OperationalList]
        size_override=None,  # type: Optional[int]
):  # type: (...) -> Tuple[List[Union[bytes, int, OperationalList]], List[Union[bytes, int, OperationalList]]]
    """Greedy variant of ``_udiff``: anchor on the first occurrence of the longest common run and recurse.

    ``size_override`` caps the run length that is tried first. Returns ``(deletions, additions)``.
    """
    base_len = len(base)
    new_len = len(new)
    if size_override is None:
        size = min(base_len, new_len)
    else:
        size = min(base_len, new_len, size_override)
    while size > 0:
        for i in range(base_len - size + 1):
            j = new.find(base[i:i + size])
            if j != -1:
                ldeletions, ladditions = _greedy_udiff(base[:i], new[:j], size)
                rdeletions, radditions = _greedy_udiff(base[i + size:], new[j + size:], size)
                return ldeletions + [size] + rdeletions, ladditions + [size] + radditions
        size -= 1
    return [base], [new, base_len]


def _greedy_smartass_diff(
        base,  # type: Union[bytes, OperationalList]
        new,  # type: Union[bytes, OperationalList]
        max_size_override=None,  # type: Optional[int]
):  # type: (...) -> Tuple[List[Union[bytes, int, OperationalList]], List[Union[bytes, int, OperationalList]]]
    """Like ``_greedy_udiff`` but binary-searches for the longest common run length instead of counting down.

    ``max_size_override`` caps the run length considered. Returns ``(deletions, additions)``.
    """
    base_len = len(base)
    new_len = len(new)
    lower = 0  # longest length known to match
    if max_size_override is None:
        upper = min(base_len, new_len)
    else:
        upper = min(base_len, new_len, max_size_override)
    found = {}  # probed length -> (base index, new index), or None when that length has no match
    while True:
        # Integer ceiling of the midpoint; ``math.ceil(x / 2)`` would silently floor under Python 2 division.
        probe = (lower + upper + 1) // 2
        if probe <= 0:
            return [base], [new, base_len]
        if probe in found:
            # The search interval has collapsed onto a length that was already probed.
            if found[probe] is None:
                return [base], [new, base_len]
            i, j = found[probe]
            ldeletions, ladditions = _greedy_smartass_diff(base[:i], new[:j], upper)
            rdeletions, radditions = _greedy_smartass_diff(base[i + probe:], new[j + probe:], upper)
            return ldeletions + [probe] + rdeletions, ladditions + [probe] + radditions
        found[probe] = None
        for i in range(base_len - probe + 1):
            j = new.find(base[i:i + probe])
            if j != -1:
                found[probe] = (i, j)
                break
        if found[probe] is None:
            upper = probe - 1
        else:
            lower = probe


def _blocks_compatible(first, second):  # type: (Sequence[int], Sequence[int]) -> bool
    """Tell whether two ``[base_start, new_start, length]`` matches can both be kept.

    They can when one lies entirely before the other in *both* sequences, i.e. they neither overlap nor cross.
    """
    first_after_second = first[0] >= second[0] + second[2] and first[1] >= second[1] + second[2]
    first_before_second = first[0] + first[2] <= second[0] and first[1] + first[2] <= second[1]
    return first_after_second or first_before_second


def _diagonal_skew(block, base_len, new_len):  # type: (Sequence[int], int, int) -> int
    """Score how unevenly a match splits the two sequences (lower is better).

    Sums the squared difference between the amounts of base/new data before the match start and after it.
    """
    base_start, new_start = block[0], block[1]
    return (base_start - new_start) ** 2 + (base_len - base_start - new_len + new_start) ** 2


def _select_blocks(candidates, base_len, new_len):  # type: (dict, int, int) -> List[List[int]]
    """Pick a mutually compatible set of matches, longest first, and return it ordered by base position.

    ``candidates`` maps a match length to the ``[base_start, new_start, length]`` matches of that length.
    """
    chosen = []
    for length in sorted(candidates, reverse=True):
        # Longer matches were chosen first; drop whatever conflicts with them.
        group = [
            block for block in candidates[length]
            if all(_blocks_compatible(block, kept) for kept in chosen)
        ]
        i = 0
        while i < len(group) - 1:
            j = i + 1
            while j < len(group):
                if _blocks_compatible(group[i], group[j]):
                    j += 1
                elif _diagonal_skew(group[i], base_len, new_len) > _diagonal_skew(group[j], base_len, new_len):
                    # The later match wins: it takes slot ``i`` and must be re-checked against everything after it.
                    group[i] = group.pop(j)
                    j = i + 1
                else:
                    group.pop(j)
            i += 1
        chosen.extend(group)
    # Compatible matches never share a base start, so this ordering is unambiguous.
    chosen.sort(key=lambda block: block[0])
    return chosen


def _lazy_lumberjack_diff(
        base,  # type: Union[bytes, OperationalList]
        new,  # type: Union[bytes, OperationalList]
        size_override=None,  # type: Optional[int]
):  # type: (...) -> Tuple[List[Union[bytes, int, OperationalList]], List[Union[bytes, int, OperationalList]]]
    """Diff by collecting common runs of at least ``size`` items, halving ``size`` until some are found.

    Every run is extended forward as far as it keeps matching, a compatible subset of the runs is kept (see
    ``_select_blocks``), and the gaps between kept runs are diffed recursively. ``size_override`` caps the initial
    ``size``. Returns ``(deletions, additions)``.
    """
    base_len = len(base)
    new_len = len(new)
    size = min(base_len, new_len) if size_override is None else min(base_len, new_len, size_override)
    while size > 0:
        candidates = {}
        seen = set()  # (base index, new index) pairs already covered by a recorded match
        for i in range(base_len - size + 1):
            window = base[i:i + size]
            j = -1
            while j < new_len - size:
                j = new.find(window, j + 1)
                if j == -1 or (i, j) in seen:
                    break
                seen.add((i, j))
                push = 0
                while (
                    i + size + push < base_len
                    and j + size + push < new_len
                    and base[i + size + push] == new[j + size + push]
                ):
                    push += 1
                    # The shifted pair is just the tail of this match; do not report it again.
                    seen.add((i + push, j + push))
                length = size + push
                candidates.setdefault(length, []).append([i, j, length])
        if not candidates:
            size //= 2
            continue
        deletions = []
        additions = []
        base_pos = 0
        new_pos = 0
        for base_start, new_start, length in _select_blocks(candidates, base_len, new_len):
            gap_deletions, gap_additions = _lazy_lumberjack_diff(
                base[base_pos:base_start], new[new_pos:new_start], size - 1
            )
            deletions.extend(gap_deletions + [length])
            additions.extend(gap_additions + [length])
            base_pos = base_start + length
            new_pos = new_start + length
        tail_deletions, tail_additions = _lazy_lumberjack_diff(base[base_pos:], new[new_pos:], size - 1)
        return deletions + tail_deletions, additions + tail_additions
    return [base], [new, base_len]


def _quick_diff(
        base,  # type: Union[bytes, OperationalList]
        new,  # type: Union[bytes, OperationalList]
):  # type: (...) -> Tuple[List[Union[bytes, int, OperationalList]], List[Union[bytes, int, OperationalList]]]
    """Single-pass diff: consume the common prefix, then resynchronise on the nearest pair of equal items.

    "Nearest" means the smallest combined offset into ``base`` and ``new``. Returns ``(deletions, additions)``.
    """
    deletions = []
    additions = []
    while True:
        if len(base) == 0:
            additions.append(new)
            break
        if len(new) == 0:
            deletions.append(base)
            additions.append(len(base))
            break
        common = 0
        while common < len(base) and common < len(new) and base[common] == new[common]:
            common += 1
        if common:
            base = base[common:]
            new = new[common:]
            deletions.append(common)
            additions.append(common)
            continue
        resync = None
        # ``depth`` is the sum of both offsets; the largest valid sum is len(base) + len(new) - 2.
        for depth in range(1, len(base) + len(new) - 1):
            for new_offset in range(depth + 1):
                base_offset = depth - new_offset
                if base_offset >= len(base) or new_offset >= len(new):
                    continue
                if base[base_offset] == new[new_offset]:
                    resync = (base_offset, new_offset)
                    break
            if resync is not None:
                break
        if resync is None:
            # Nothing left in common: replace the whole remainder.
            return deletions + [base], additions + [new, len(base)]
        base_offset, new_offset = resync
        deletions.append(base[:base_offset])
        additions.append(new[:new_offset])
        additions.append(base_offset)
        base = base[base_offset:]
        new = new[new_offset:]
    return deletions, additions


DIFF_ALGORITHMS = {
    'quick': _quick_diff,
    'lazy-lumberjack': _lazy_lumberjack_diff,
    'greedy-smartass': _greedy_smartass_diff,
    'greedy-u': _greedy_udiff,
    'u': _udiff,
}


def _split_keepends(data, split):  # type: (bytes, bytes) -> OperationalList
    """Split ``data`` on ``split`` into an ``OperationalList`` whose chunks keep their trailing separator."""
    chunks = OperationalList(data.split(split))
    for index in range(len(chunks) - 1):
        chunks[index] += split
    return chunks


def byte_diff(
        base,  # type: bytes
        new,  # type: bytes
        split=None,  # type: Optional[bytes]
        algorithm='lazy-lumberjack',  # type: Literal['quick', 'lazy-lumberjack', 'greedy-smartass', 'greedy-u', 'u']
):  # type: (...) -> List[Diff]
    """Compute the list of ``Diff`` objects that turns ``base`` into ``new``.

    Args:
        base: original data.
        new: target data.
        split: optional separator; when given, both inputs are split on it and diffed chunk by chunk, so pointers
            count chunks and diff contents are ``OperationalList`` instances.
        algorithm: key into ``DIFF_ALGORITHMS``.

    Returns:
        All deletions followed by all insertions; every pointer refers to a position in ``base``.

    Raises:
        KeyError: if ``algorithm`` is not a key of ``DIFF_ALGORITHMS``.
    """
    if split is not None:
        base = _split_keepends(base, split)
        new = _split_keepends(new, split)
    deletions, additions = DIFF_ALGORITHMS[algorithm](base, new)
    diffs = []
    pointer = 0
    append_to_last = False  # True while the previous piece was content, so adjacent content gets merged
    for piece in deletions:
        if isinstance(piece, (bytes, OperationalList)):
            if append_to_last:
                diffs[-1].content += piece
                pointer += len(piece)
            elif len(piece):
                diffs.append(Diff(False, pointer, piece))
                pointer += len(piece)
                append_to_last = True
        elif piece:
            # A zero skip leaves ``append_to_last`` untouched: the surrounding content is still adjacent.
            pointer += piece
            append_to_last = False
    pointer = 0
    append_to_last = False
    for piece in additions:
        if isinstance(piece, (bytes, OperationalList)):
            # Inserted content does not consume base items, so the pointer does not move.
            if append_to_last:
                diffs[-1].content += piece
            elif len(piece):
                diffs.append(Diff(True, pointer, piece))
                append_to_last = True
        elif piece:
            pointer += piece
            append_to_last = False
    return diffs


def byte_patch(base, diffs, split=None):  # type: (bytes, List[Diff], Optional[bytes]) -> bytes
    """Apply ``diffs`` to ``base`` and return the patched bytes.

    Diffs are applied in list order; after each one, the pointers of the remaining diffs are shifted to account for
    the change. The caller's ``Diff`` objects are not modified.

    Args:
        base: original data.
        diffs: diffs whose pointers refer to the original ``base``.
        split: must be the separator the diffs were computed with (``None`` for plain byte diffs).
    """
    if split is not None:
        base = _split_keepends(base, split)
    pending = [Diff(diff.addition, diff.pointer, diff.content) for diff in diffs]
    for index in range(len(pending)):
        diff = pending[index]
        content_len = len(diff.content)
        if diff.addition:
            base = base[:diff.pointer] + diff.content + base[diff.pointer:]
            shift = content_len
        else:
            base = base[:diff.pointer] + base[diff.pointer + content_len:]
            shift = -content_len
        for later_index in range(index + 1, len(pending)):
            later = pending[later_index]
            # Anything at the same pointer as an insertion now sits behind the inserted content.
            if later.pointer > diff.pointer or (diff.addition and later.pointer == diff.pointer):
                later.pointer += shift
    if isinstance(base, bytes):
        return base
    return bytes().join(base)


def _diffs_independent(first, second):  # type: (Diff, Diff) -> bool
    """Tell whether ``first`` followed by ``second`` can be kept as they are.

    Two deletions are independent when their ranges do not overlap. Otherwise the diffs are independent when their
    pointers differ, or when an insertion precedes a deletion at the same pointer.
    """
    if not first.addition and not second.addition:
        return not (
            first.pointer <= second.pointer < first.pointer + len(first)
            or second.pointer <= first.pointer < second.pointer + len(second)
        )
    return first.pointer != second.pointer or (first.addition and not second.addition)


def unify_diffs(our_diff, their_diff):  # type: (List[Diff], List[Diff]) -> List[Diff]
    """Merge two diff lists made against the same base into a single list.

    Duplicate and contained diffs are collapsed, and diffs whose contents overlap are joined.

    Raises:
        ConflictError: if two diffs touch the same place and cannot be reconciled.
    """
    sequences = list(our_diff)
    sequences.extend(their_diff)
    i = 0
    while i < len(sequences) - 1:
        j = i + 1
        while j < len(sequences):
            first = sequences[i]
            second = sequences[j]
            if _diffs_independent(first, second):
                j += 1
            elif not first.addition and second.addition and first.pointer == second.pointer:
                # A deletion before an insertion at the same pointer: swapping their order resolves the conflict.
                sequences[i], sequences[j] = second, first
                j = i + 1
            elif second in first:  # ``first`` contains ``second`` or is equal to it
                sequences.pop(j)
            elif first in second:
                sequences.pop(i)
                j = i + 1
            else:
                # Last resort: look for overlapping content between the two conflicting diffs, in either order.
                trimmed = first - second
                if len(trimmed) != len(first):
                    sequences[i] = trimmed + second
                    sequences.pop(j)
                else:
                    trimmed = second - first
                    if len(trimmed) == len(second):
                        raise ConflictError('Couldn\'t unify deltas for three-way merge')
                    sequences[j] = trimmed + first
                    sequences.pop(i)
                    j = i + 1
        i += 1
    return sequences