/ src / patchman / _solver.py
_solver.py
  1  # Copyright (C) 2025  Armin "Era" Ramezani <e@4d2.org>
  2  #
  3  # This file is a part of patchman.
  4  #
  5  # patchman is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser General Public
  6  # License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later
  7  # version.
  8  #
  9  # patchman is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied
 10  # warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public License for more
 11  # details.
 12  #
 13  # You should have received a copy of the GNU Lesser General Public License along with patchman.  If not, see
 14  # <https://www.gnu.org/licenses/>.
 15  #
 16  """Diff functionality."""
 17  
 18  import math
 19  
 20  try:
 21      from typing import TYPE_CHECKING
 22  except ImportError:
 23      TYPE_CHECKING = False
 24  
 25  if TYPE_CHECKING:
 26      from typing import Any, List, Literal, Optional, Sequence, Tuple, Union
 27  
 28  
 29  
 30  class ConflictError(Exception):
 31      pass
 32  
 33  
 34  class OperationalList(object):
 35      def __init__(self, seq = ()):  # type: (Sequence) -> None
 36          self._list = list(seq)
 37  
 38      def __add__(self, other):  # type: (OperationalList) -> OperationalList
 39          assert isinstance(other, OperationalList)
 40          return OperationalList(self._list + other._list)
 41  
 42      def __contains__(self, item):  # type: (OperationalList) -> bool
 43          return self.find(item) != -1
 44  
 45      def __eq__(self, other):  # type: (object) -> bool
 46          if isinstance(other, OperationalList):
 47              return self._list == other._list
 48          return self._list == other
 49  
 50      def __getitem__(self, item):  # type: (Union[slice, int]) -> Any
 51          if isinstance(item, int):
 52              return self._list[item]
 53          return self.__class__(self._list[item])
 54  
 55      def __len__(self):  # type: () -> int
 56          return len(self._list)
 57  
 58      def __repr__(self):  # type: () -> str
 59          return repr(self._list)
 60  
 61      def __setitem__(self, key, value):  # type: (int, Any) -> None
 62          self._list[key] = value
 63  
 64      def find(self, sub, start = 0):  # type: (OperationalList, int) -> int
 65          if start < 0:
 66              start = len(self) + start
 67          for i in range(start, len(self) - len(sub) + 1):
 68              if self[i:i + len(sub)] == sub:
 69                  return i
 70          return -1
 71  
 72  
 73  
 74  class Diff(object):
 75      def __init__(self, addition, pointer, content):  # type: (bool, int, Union[bytes, OperationalList]) -> None
 76          self.addition = addition
 77          self.pointer = pointer
 78          self.content = content
 79  
 80      def __add__(self, other):  # type: (Diff) -> Diff
 81          if not (
 82              self.addition is other.addition
 83              and ((self.pointer == other.pointer) if self.addition else (self.pointer + len(self) == other.pointer))
 84          ):
 85              raise ConflictError('Diff addition failed.')
 86          return self.__class__(self.addition, self.pointer, self.content + other.content)
 87  
 88      def __eq__(self, other):  # type: (Diff) -> bool
 89          return self.addition == other.addition and self.pointer == other.pointer and self.content == other.content
 90  
 91      def __len__(self):  # type: () -> int
 92          return len(self.content)
 93  
 94      def __repr__(self):  # type: () -> str
 95          return repr([self.addition, self.pointer, self.content])
 96  
 97      def __sub__(self, other):  # type: (Diff) -> Diff
 98          if self.addition is not other.addition:
 99              raise ConflictError('Diff subtraction failed.')
100          if self.addition:
101              if self.pointer == other.pointer:
102                  i = min(len(self), len(other))
103                  while i != 0:
104                      if self.content[-i:] == other.content[:i]:
105                          break
106                      i -= 1
107                  return self.__class__(self.addition, self.pointer, self.content[:len(self) - i])
108          else:
109              i = min(len(self), len(other))
110              while i != 0:
111                  if self.content[-i:] == other.content[:i] and self.pointer + len(self) - i == other.pointer:
112                      break
113                  i -= 1
114              return self.__class__(self.addition, self.pointer, self.content[:len(self) - i])
115          return self.__class__(self.addition, self.pointer, self.content)
116  
117      def __contains__(self, item):  # type: (Diff) -> bool
118          if self.addition is not item.addition:
119              return False
120          if self.addition:
121              if item.content in self.content and self.pointer == item.pointer:
122                  return True
123              return False
124          indexes = []
125          j = 0
126          while j < len(self):
127              k = self.content.find(item.content, j)
128              if k == -1:
129                  break
130              indexes.append(k)
131              j = k + 1
132          return any(self.pointer + i == item.pointer for i in indexes)
133  
134  
135  def _udiff(
136          base,  # type: Union[bytes, OperationalList]
137          new,  # type: Union[bytes, OperationalList]
138          size_override = None,  # type: Optional[int]
139  ):  # type: (...) -> Tuple[List[Union[bytes, int, OperationalList]], List[Union[bytes,  int, OperationalList]]]
140      base_len = len(base)
141      new_len = len(new)
142      if size_override is None:
143          size = min(base_len, new_len)
144      else:
145          size = min(base_len, new_len, size_override)
146      while size != 0:
147          matches = []
148          for i in range(base_len - size + 1):
149              # FIXME: its probably faster to use regex here
150              j = 0
151              while j < new_len:
152                  k = new.find(base[i:i + size], j)
153                  if k == -1:
154                      break
155                  matches.append([[i, i + size], [k, k + size]])
156                  j = k + 1
157          penalty = deletions = additions = None
158          for candidate in matches:
159              ldeletions, ladditions = _udiff(base[:candidate[0][0]], new[:candidate[1][0]], size)
160              rdeletions, radditions = _udiff(base[candidate[0][1]:], new[candidate[1][1]:], size)
161              temp_deletions = ldeletions + [candidate[0][1] - candidate[0][0]] + rdeletions
162              temp_additions = ladditions + [candidate[1][1] - candidate[1][0]] + radditions
163              temp_penalty = 0
164              for j in temp_deletions + temp_additions:
165                  if isinstance(j, bytes):
166                      temp_penalty += len(j)
167              if penalty is None or temp_penalty < penalty:
168                  deletions = temp_deletions
169                  additions = temp_additions
170                  penalty = temp_penalty
171          if penalty is not None:
172              return deletions, additions
173          size -= 1
174      return [base], [new, base_len]
175  
176  
177  def _greedy_udiff(
178          base,  # type: Union[bytes, OperationalList]
179          new,  # type: Union[bytes, OperationalList]
180          size_override = None,  # type: Optional[int]
181  ):  # type: (...) -> Tuple[List[Union[bytes, int, OperationalList]], List[Union[bytes,  int, OperationalList]]]
182      base_len = len(base)
183      new_len = len(new)
184      if size_override is None:
185          size = min(base_len, new_len)
186      else:
187          size = min(base_len, new_len, size_override)
188      while size != 0:
189          for i in range(base_len - size + 1):
190              j = new.find(base[i:i + size])
191              if j != -1:
192                  ldeletions, ladditions = _greedy_udiff(base[:i], new[:j], size)
193                  rdeletions, radditions = _greedy_udiff(base[i + size:], new[j + size:], size)
194                  return ldeletions + [size] + rdeletions, ladditions + [size] + radditions
195          size -= 1
196      return [base], [new, base_len]
197  
198  
199  def _greedy_smartass_diff(
200          base,  # type: Union[bytes, OperationalList]
201          new,  # type: Union[bytes, OperationalList]
202          max_size_override = None,  # type: Optional[int]
203  ):  # type: (...) -> Tuple[List[Union[bytes, int, OperationalList]], List[Union[bytes,  int, OperationalList]]]
204      base_len = len(base)
205      new_len = len(new)
206      size_range = [0, min(base_len, new_len) if max_size_override is None else min(base_len, new_len, max_size_override)]
207      candidates = {}
208      while True:
209          candidate = int(math.ceil((size_range[0] + size_range[1]) / 2))
210          if candidate == 0:
211              return [base], [new, base_len]
212          if candidate in candidates.keys():
213              if candidates[candidate] is None:
214                  return [base], [new, base_len]
215              ldeletions, ladditions = _greedy_smartass_diff(
216                  base[:candidates[candidate][0]], new[:candidates[candidate][1]], size_range[1]
217              )
218              rdeletions, radditions = _greedy_smartass_diff(
219                  base[candidates[candidate][0] + candidate:], new[candidates[candidate][1] + candidate:], size_range[1]
220              )
221              return ldeletions + [candidate] + rdeletions, ladditions + [candidate] + radditions
222          for i in range(base_len - candidate + 1):
223              j = new.find(base[i:i + candidate])
224              if j != -1:
225                  candidates[candidate] = [i, j]
226                  size_range[0] = candidate
227                  break
228          if size_range[0] != candidate:
229              candidates[candidate] = None
230              size_range[1] = candidate - 1
231  
232  
233  def _lazy_lumberjack_diff(
234          base,  # type: Union[bytes, OperationalList]
235          new,  # type: Union[bytes, OperationalList]
236          size_override = None,  # type: Optional[int]
237  ):  # type: (...) -> Tuple[List[Union[bytes, int, OperationalList]], List[Union[bytes,  int, OperationalList]]]
238      base_len = len(base)
239      new_len = len(new)
240      size = min(base_len, new_len) if size_override is None else min(base_len, new_len, size_override)
241      while True:
242          if size == 0:
243              return [base], [new, base_len]
244          candidates = {}
245          skip_list = []
246          for i in range(base_len - size + 1):
247              j = -1
248              while j < new_len - size:
249                  j = new.find(base[i:i + size], j + 1)
250                  if j == -1 or [i, j] in skip_list:
251                      break
252                  skip_list.append([i, j])
253                  push = 0
254                  while (
255                      i + size + push < base_len
256                      and j + size + push < new_len
257                      and base[i + size + push] == new[j + size + push]
258                  ):
259                      push += 1
260                      skip_list.append([i + push, j + push])
261                  shoe = size + push
262                  candidates.setdefault(shoe, [])
263                  candidates[shoe].append([i, j, shoe])
264          if candidates:
265              selections = {}
266              for k in sorted(candidates.keys(), reverse=True):
267                  for selection in selections.values():
268                      marked = []
269                      for candidate in candidates[k]:
270                          if (
271                              (candidate[0] < selection[0] + selection[2] or candidate[1] < selection[1] + selection[2])
272                              and (candidate[0] + k > selection[0] or candidate[1] + k > selection[1])
273                          ):
274                              marked.append(candidate)
275                      for m in marked:
276                          candidates[k].remove(m)
277                  i = 0
278                  while i < len(candidates[k]) - 1:
279                      j = i + 1
280                      while j < len(candidates[k]):
281                          if (
282                              candidates[k][i][0] <= candidates[k][j][0] < candidates[k][i][0] + k
283                              or candidates[k][i][0] - k < candidates[k][j][0] <= candidates[k][i][0]
284                              or candidates[k][i][1] <= candidates[k][j][1] < candidates[k][i][1] + k
285                              or candidates[k][i][1] - k < candidates[k][j][1] <= candidates[k][i][1]
286                          ):
287                              if (
288                                  (
289                                      (candidates[k][i][0] - candidates[k][i][1]) ** 2
290                                      + (base_len - candidates[k][i][0] - new_len + candidates[k][i][1]) ** 2
291                                  )
292                                  > (
293                                      (candidates[k][j][0] - candidates[k][j][1]) ** 2
294                                      + (base_len - candidates[k][j][0] - new_len + candidates[k][j][1]) ** 2
295                                  )
296                              ):
297                                  candidates[k][i] = candidates[k][j]
298                                  j = i + 1
299                              candidates[k].pop(j)
300                          else:
301                              j += 1
302                      i += 1
303                  for candidate in candidates[k]:
304                      selections[candidate[0]] = candidate
305              selections = [selections[k] for k in sorted(selections.keys())]
306              deletions, additions = _lazy_lumberjack_diff(base[:selections[0][0]], new[:selections[0][1]], size - 1)
307              deletions.append(selections[0][2])
308              additions.append(selections[0][2])
309              for i in range(len(selections) - 1):
310                  tdeletions, tadditions = _lazy_lumberjack_diff(
311                      base[selections[i][0] + selections[i][2]:selections[i + 1][0]],
312                      new[selections[i][1] + selections[i][2]:selections[i + 1][1]],
313                      size - 1,
314                  )
315                  deletions.extend(tdeletions + [selections[i + 1][2]])
316                  additions.extend(tadditions + [selections[i + 1][2]])
317              tdeletions, tadditions = _lazy_lumberjack_diff(
318                  base[selections[-1][0] + selections[-1][2]:], new[selections[-1][1] + selections[-1][2]:], size - 1
319              )
320              return deletions + tdeletions, additions + tadditions
321          size = size // 2
322  
323  
324  def _quick_diff(
325          base,  # type: Union[bytes, OperationalList]
326          new,  # type: Union[bytes, OperationalList]
327  ):  # type: (...) -> Tuple[List[Union[bytes, int, OperationalList]], List[Union[bytes,  int, OperationalList]]]
328      deletions = []
329      additions = []
330      while True:
331          if len(base) == 0:
332              additions.append(new)
333              break
334          if len(new) == 0:
335              deletions.append(base)
336              additions.append(len(base))
337              break
338          i = 0
339          while i < len(base) and i < len(new) and base[i] == new[i]:
340              i += 1
341          if i:
342              base = base[i:]
343              new = new[i:]
344              deletions.append(i)
345              additions.append(i)
346          else:
347              depth = 1
348              while True:
349                  if depth == len(base) + len(new) - 1:
350                      return deletions + [base], additions + [new, len(base)]
351                  for i in range(depth + 1):
352                      if depth - i >= len(base) or i >= len(new):
353                          continue
354                      if base[depth - i] == new[i]:
355                          deletions.append(base[:depth - i])
356                          additions.append(new[:i])
357                          additions.append(depth - i)
358                          base = base[depth - i:]
359                          new = new[i:]
360                          depth = None
361                          break
362                  if not depth:
363                      break
364                  depth += 1
365      return deletions, additions
366  
367  
# Available diff strategies, keyed by the name ``byte_diff`` accepts as ``algorithm``. Each one is called as
# ``strategy(base, new)`` and returns a ``(deletions, additions)`` pair.
DIFF_ALGORITHMS = dict((
    ('quick', _quick_diff),
    ('lazy-lumberjack', _lazy_lumberjack_diff),
    ('greedy-smartass', _greedy_smartass_diff),
    ('greedy-u', _greedy_udiff),
    ('u', _udiff),
))
375  
376  
def byte_diff(
        base,  # type: bytes
        new,  # type: bytes
        split = None,  # type: Optional[bytes]
        algorithm = 'lazy-lumberjack',  # type: Literal['quick', 'lazy-lumberjack', 'greedy-smartass', 'greedy-u', 'u']
):  # type: (...) -> List[Diff]
    """Compute the hunks that turn ``base`` into ``new``.

    :param base: the original bytes.
    :param new: the modified bytes.
    :param split: optional separator. When given, both inputs are cut after each occurrence of it (the separator
        stays on the end of its piece) and diffed piece by piece instead of byte by byte.
    :param algorithm: name of the ``DIFF_ALGORITHMS`` entry to use.
    :returns: deletion hunks followed by addition hunks. Consecutive chunks of the same kind are merged into a
        single hunk, and empty chunks produce no hunk.
    """
    def as_pieces(data):  # type: (bytes) -> OperationalList
        pieces = OperationalList(data.split(split))
        for index in range(len(pieces) - 1):
            pieces[index] += split
        return pieces

    if split is not None:
        base, new = as_pieces(base), as_pieces(new)
    deletions, additions = DIFF_ALGORITHMS[algorithm](base, new)
    chunk_types = (bytes, OperationalList)
    hunks = []

    # Deletion positions advance over unchanged runs (ints) and over the deleted chunks themselves.
    position = 0
    extend_previous = False
    for entry in deletions:
        if not isinstance(entry, chunk_types):
            if entry:
                position += entry
                extend_previous = False
            continue
        if extend_previous:
            hunks[-1].content += entry
            position += len(entry)
        elif len(entry):
            hunks.append(Diff(False, position, entry))
            position += len(entry)
            extend_previous = True

    # Addition positions advance only over the integers in ``additions``.
    position = 0
    extend_previous = False
    for entry in additions:
        if not isinstance(entry, chunk_types):
            if entry:
                position += entry
                extend_previous = False
            continue
        if extend_previous:
            hunks[-1].content += entry
        elif len(entry):
            hunks.append(Diff(True, position, entry))
            extend_previous = True
    return hunks
419  
420  
def byte_patch(base, diffs, split = None):  # type: (bytes, List[Diff], Optional[bytes]) -> bytes
    """Apply ``diffs`` to ``base`` and return the patched bytes.

    Hunks are applied in list order. After each one, the pointers of the hunks still to be applied are shifted by
    the number of elements inserted or removed, so all pointers may be given relative to the original ``base``.
    The caller's ``Diff`` objects are not modified.

    :param split: the separator the diffs were computed with (if any); ``base`` is cut the same way before
        patching and the pieces are joined back into bytes afterwards.
    """
    if split is not None:
        base = OperationalList(base.split(split))
        for i in range(len(base) - 1):
            base[i] += split
    # Work on copies because pointers are rewritten below.
    pending = [Diff(diff.addition, diff.pointer, diff.content) for diff in diffs]
    for index, diff in enumerate(pending):
        content_len = len(diff.content)
        if diff.addition:
            base = base[:diff.pointer] + diff.content + base[diff.pointer:]
            changes = content_len
        else:
            base = base[:diff.pointer] + base[diff.pointer + content_len:]
            changes = -content_len
        for later in pending[index + 1:]:
            # Everything after this hunk moves. An insertion also pushes a later hunk at the very same position
            # behind the inserted content.
            if later.pointer > diff.pointer or (diff.addition and later.pointer == diff.pointer):
                later.pointer += changes
    if isinstance(base, bytes):
        return base
    # One join is linear; the previous ``+=`` loop re-copied the accumulated bytes for every piece.
    return bytes().join(base)
445  
446  
def unify_diffs(our_diff, their_diff):  # type: (List[Diff], List[Diff]) -> List[Diff]
    """Combine two hunk lists made against the same base into one list, for a three-way merge.

    Every pair of hunks is checked for a conflict. A conflict is resolved by one of these, tried in order:

    - moving an insertion in front of a deletion at the same position;
    - dropping a hunk that another hunk already contains;
    - joining two hunks where the end of one overlaps the start of the other.

    :raises ConflictError: if two conflicting hunks cannot be reconciled by any of the above.
    """
    sequences = list(our_diff)
    sequences.extend(their_diff)

    def compatible(first, second):  # type: (Diff, Diff) -> bool
        if not first.addition and not second.addition:
            # Two deletions are compatible when their ranges do not overlap.
            return not (
                first.pointer <= second.pointer < first.pointer + len(first)
                or second.pointer <= first.pointer < second.pointer + len(second)
            )
        # Otherwise, hunks at different positions are independent, and an insertion listed before a deletion
        # at the same position is unambiguous.
        return first.pointer != second.pointer or (first.addition and not second.addition)

    i = 0
    while i < len(sequences) - 1:
        j = i + 1
        while j < len(sequences):
            first, second = sequences[i], sequences[j]
            if compatible(first, second):
                j += 1
            elif not first.addition and second.addition and first.pointer == second.pointer:
                # Swapping puts the insertion first, which is conflict-free.
                sequences[i], sequences[j] = second, first
                j = i + 1
            elif second in first:  # ``first`` contains ``second`` or equals it.
                sequences.pop(j)
            elif first in second:
                sequences.pop(i)
                j = i + 1
            else:  # Last resort: join the two hunks if the end of one overlaps the start of the other.
                try:
                    trimmed = first - second
                    if len(trimmed) == len(first):
                        raise ConflictError('Couldn\'t unify deltas for three-way merge')
                    sequences[i] = trimmed + second
                    sequences.pop(j)
                except ConflictError:
                    trimmed = second - first
                    if len(trimmed) == len(second):
                        # Report the offending hunks in the error instead of printing them to stdout.
                        raise ConflictError(
                            'Couldn\'t unify deltas for three-way merge: %r vs %r' % (first, second)
                        )
                    sequences[j] = trimmed + first
                    sequences.pop(i)
                # Either way, the hunk now at ``i`` grew, so re-check it against every hunk after it; previously the
                # first branch skipped this and could leave overlapping deletions behind.
                j = i + 1
        i += 1
    return sequences