// IntervalTree.cs
using System;
using System.Collections.Generic;

namespace ARMeilleure.Translation
{
    /// <summary>
    /// An Augmented Interval Tree based off of the "TreeDictionary"'s Red-Black Tree. Allows fast overlap checking of ranges.
    /// </summary>
    /// <remarks>
    /// Nodes are ordered by the start key of their interval. Every node additionally caches the largest end key
    /// found in its subtree (<c>Max</c>), which lets overlap queries skip whole subtrees.
    /// Overlap checks use strict comparisons, so an interval whose end equals the start of the queried range
    /// (or vice versa) is not reported as overlapping.
    /// </remarks>
    /// <typeparam name="TK">Key</typeparam>
    /// <typeparam name="TV">Value</typeparam>
    public class IntervalTree<TK, TV> where TK : IComparable<TK>
    {
        // Number of extra slots allocated whenever the overlaps array has to grow.
        private const int ArrayGrowthSize = 32;

        private const bool Black = true;
        private const bool Red = false;
        private IntervalTreeNode<TK, TV> _root = null;
        private int _count = 0;

        /// <summary>
        /// Number of intervals currently stored in the tree.
        /// </summary>
        public int Count => _count;

        #region Public Methods

        /// <summary>
        /// Gets the value of the interval whose start key is <paramref name="key"/>.
        /// </summary>
        /// <param name="key">Key of the node value to get</param>
        /// <param name="value">Value with the given <paramref name="key"/>, or the default value if not found</param>
        /// <exception cref="ArgumentNullException"><paramref name="key"/> is null</exception>
        /// <returns>True if the key is on the tree, false otherwise</returns>
        public bool TryGet(TK key, out TV value)
        {
            IntervalTreeNode<TK, TV> node = GetNode(key);

            if (node == null)
            {
                value = default;
                return false;
            }

            value = node.Value;
            return true;
        }

        /// <summary>
        /// Returns the start addresses of the intervals whose start and end keys overlap the given range.
        /// </summary>
        /// <param name="start">Start of the range</param>
        /// <param name="end">End of the range</param>
        /// <param name="overlaps">Overlaps array to place results in. It is grown (or allocated, if null) when it is too small</param>
        /// <param name="overlapCount">Index to start writing results into the array. Defaults to 0</param>
        /// <returns><paramref name="overlapCount"/> plus the number of intervals found</returns>
        public int Get(TK start, TK end, ref TK[] overlaps, int overlapCount = 0)
        {
            GetKeys(_root, start, end, ref overlaps, ref overlapCount);

            return overlapCount;
        }

        /// <summary>
        /// Adds a new interval into the tree whose start is <paramref name="start"/>, end is <paramref name="end"/> and value is <paramref name="value"/>.
        /// </summary>
        /// <param name="start">Start of the range to add</param>
        /// <param name="end">End of the range to insert</param>
        /// <param name="value">Value to add</param>
        /// <param name="updateFactoryCallback">Optional factory used to create a new value if <paramref name="start"/> is already on the tree. If null, an existing interval is left untouched</param>
        /// <exception cref="ArgumentNullException"><paramref name="value"/> is null</exception>
        /// <returns>True if the value was added, false if the start key was already in the dictionary</returns>
        public bool AddOrUpdate(TK start, TK end, TV value, Func<TK, TV, TV> updateFactoryCallback)
        {
            ArgumentNullException.ThrowIfNull(value);

            return BSTInsert(start, end, value, updateFactoryCallback, out _);
        }

        /// <summary>
        /// Gets an existing or adds a new interval into the tree whose start is <paramref name="start"/>, end is <paramref name="end"/> and value is <paramref name="value"/>.
        /// </summary>
        /// <param name="start">Start of the range to add</param>
        /// <param name="end">End of the range to insert</param>
        /// <param name="value">Value to add</param>
        /// <exception cref="ArgumentNullException"><paramref name="value"/> is null</exception>
        /// <returns><paramref name="value"/> if <paramref name="start"/> is not yet on the tree, or the existing value otherwise</returns>
        public TV GetOrAdd(TK start, TK end, TV value)
        {
            ArgumentNullException.ThrowIfNull(value);

            BSTInsert(start, end, value, null, out IntervalTreeNode<TK, TV> node);
            return node.Value;
        }

        /// <summary>
        /// Removes a value from the tree, searching for it with <paramref name="key"/>.
        /// </summary>
        /// <param name="key">Key of the node to remove</param>
        /// <exception cref="ArgumentNullException"><paramref name="key"/> is null</exception>
        /// <returns>Number of deleted values</returns>
        public int Remove(TK key)
        {
            int removed = Delete(key);

            _count -= removed;

            return removed;
        }

        /// <summary>
        /// Collects the values of all the nodes in the tree into a new list.
        /// </summary>
        /// <returns>A list of all values sorted by Key Order</returns>
        public List<TV> AsList()
        {
            List<TV> list = new();

            AddToList(_root, list);

            return list;
        }

        #endregion

        #region Private Methods (BST)

        /// <summary>
        /// Adds all values that are children of or contained within <paramref name="node"/> into <paramref name="list"/>, in Key Order.
        /// </summary>
        /// <param name="node">The node to search for values within</param>
        /// <param name="list">The list to add values to</param>
        private void AddToList(IntervalTreeNode<TK, TV> node, List<TV> list)
        {
            if (node == null)
            {
                return;
            }

            AddToList(node.Left, list);

            list.Add(node.Value);

            AddToList(node.Right, list);
        }

        /// <summary>
        /// Retrieve the node reference whose key is <paramref name="key"/>, or null if no such node exists.
        /// </summary>
        /// <param name="key">Key of the node to get</param>
        /// <exception cref="ArgumentNullException"><paramref name="key"/> is null</exception>
        /// <returns>Node reference in the tree</returns>
        private IntervalTreeNode<TK, TV> GetNode(TK key)
        {
            ArgumentNullException.ThrowIfNull(key);

            IntervalTreeNode<TK, TV> node = _root;
            while (node != null)
            {
                int cmp = key.CompareTo(node.Start);
                if (cmp < 0)
                {
                    node = node.Left;
                }
                else if (cmp > 0)
                {
                    node = node.Right;
                }
                else
                {
                    return node;
                }
            }
            return null;
        }

        /// <summary>
        /// Retrieve all keys that overlap the given start and end keys.
        /// </summary>
        /// <param name="node">Root of the subtree to search</param>
        /// <param name="start">Start of the range</param>
        /// <param name="end">End of the range</param>
        /// <param name="overlaps">Overlaps array to place results in. Grown or allocated when too small</param>
        /// <param name="overlapCount">Overlaps count to update</param>
        private void GetKeys(IntervalTreeNode<TK, TV> node, TK start, TK end, ref TK[] overlaps, ref int overlapCount)
        {
            // Nothing in this subtree ends after the start of the range, so nothing here can overlap.
            if (node == null || start.CompareTo(node.Max) >= 0)
            {
                return;
            }

            GetKeys(node.Left, start, end, ref overlaps, ref overlapCount);

            // If the range ends at or before this node's start, neither this node nor anything
            // in the right subtree (which only has larger start keys) can overlap.
            bool endsOnRight = end.CompareTo(node.Start) > 0;
            if (endsOnRight)
            {
                if (start.CompareTo(node.End) < 0)
                {
                    // Array.Resize allocates a fresh array when given null, so a null
                    // overlaps array is handled the same way as one that is too small.
                    if (overlaps == null || overlaps.Length <= overlapCount)
                    {
                        Array.Resize(ref overlaps, overlapCount + ArrayGrowthSize);
                    }

                    overlaps[overlapCount++] = node.Start;
                }

                GetKeys(node.Right, start, end, ref overlaps, ref overlapCount);
            }
        }

        /// <summary>
        /// Propagate an increase in max value starting at the given node, heading up the tree.
        /// This should only be called if the max increases - not for rebalancing or removals.
        /// </summary>
        /// <param name="node">The node to start propagating from</param>
        private static void PropagateIncrease(IntervalTreeNode<TK, TV> node)
        {
            TK max = node.Max;
            IntervalTreeNode<TK, TV> ptr = node;

            while ((ptr = ptr.Parent) != null)
            {
                if (max.CompareTo(ptr.Max) > 0)
                {
                    ptr.Max = max;
                }
                else
                {
                    // This ancestor already covers the new max; so do all ancestors above it.
                    break;
                }
            }
        }

        /// <summary>
        /// Propagate recalculating max value starting at the given node, heading up the tree.
        /// This fully recalculates the max value from all children when there is potential for it to decrease.
        /// </summary>
        /// <param name="node">The node to start propagating from</param>
        private static void PropagateFull(IntervalTreeNode<TK, TV> node)
        {
            IntervalTreeNode<TK, TV> ptr = node;

            do
            {
                TK max = ptr.End;

                if (ptr.Left != null && ptr.Left.Max.CompareTo(max) > 0)
                {
                    max = ptr.Left.Max;
                }

                if (ptr.Right != null && ptr.Right.Max.CompareTo(max) > 0)
                {
                    max = ptr.Right.Max;
                }

                ptr.Max = max;
            } while ((ptr = ptr.Parent) != null);
        }

        /// <summary>
        /// Insertion Mechanism for the interval tree. Similar to a BST insert, with the start of the range as the key.
        /// Iterates the tree starting from the root and inserts a new node where all children in the left subtree are less than <paramref name="start"/>, and all children in the right subtree are greater than <paramref name="start"/>.
        /// If <paramref name="start"/> is already on the tree, the existing node is only modified when <paramref name="updateFactoryCallback"/> is provided: its value is replaced by the callback result and its end is set to <paramref name="end"/>.
        /// Post insertion, the "max" value of the node and all parents are updated.
        /// </summary>
        /// <param name="start">Start of the range to insert</param>
        /// <param name="end">End of the range to insert</param>
        /// <param name="value">Value to insert</param>
        /// <param name="updateFactoryCallback">Optional factory used to create a new value if <paramref name="start"/> is already on the tree</param>
        /// <param name="outNode">Node that was inserted or modified</param>
        /// <returns>True if <paramref name="start"/> was not yet on the tree, false otherwise</returns>
        private bool BSTInsert(TK start, TK end, TV value, Func<TK, TV, TV> updateFactoryCallback, out IntervalTreeNode<TK, TV> outNode)
        {
            IntervalTreeNode<TK, TV> parent = null;
            IntervalTreeNode<TK, TV> node = _root;

            while (node != null)
            {
                parent = node;
                int cmp = start.CompareTo(node.Start);
                if (cmp < 0)
                {
                    node = node.Left;
                }
                else if (cmp > 0)
                {
                    node = node.Right;
                }
                else
                {
                    outNode = node;

                    if (updateFactoryCallback != null)
                    {
                        // Replace
                        node.Value = updateFactoryCallback(start, node.Value);

                        int endCmp = end.CompareTo(node.End);

                        if (endCmp > 0)
                        {
                            node.End = end;
                            if (end.CompareTo(node.Max) > 0)
                            {
                                // Only the cached max values change here. The shape of the tree is
                                // untouched, so no recolouring or rebalancing must take place.
                                node.Max = end;
                                PropagateIncrease(node);
                            }
                        }
                        else if (endCmp < 0)
                        {
                            // The end shrank, so the max of this node and its ancestors may decrease.
                            node.End = end;
                            PropagateFull(node);
                        }
                    }

                    return false;
                }
            }

            IntervalTreeNode<TK, TV> newNode = new(start, end, value, parent);
            if (newNode.Parent == null)
            {
                _root = newNode;
            }
            else if (start.CompareTo(parent.Start) < 0)
            {
                parent.Left = newNode;
            }
            else
            {
                parent.Right = newNode;
            }

            PropagateIncrease(newNode);
            _count++;
            RestoreBalanceAfterInsertion(newNode);
            outNode = newNode;
            return true;
        }

        /// <summary>
        /// Removes the value from the dictionary after searching for it with <paramref name="key"/>.
        /// </summary>
        /// <remarks>
        /// When the node has two children, its predecessor is unlinked instead and the predecessor's
        /// interval and value are copied into the node, so the node object itself stays in the tree.
        /// </remarks>
        /// <param name="key">Key to search for</param>
        /// <returns>Number of deleted values</returns>
        private int Delete(TK key)
        {
            IntervalTreeNode<TK, TV> nodeToDelete = GetNode(key);

            if (nodeToDelete == null)
            {
                return 0;
            }

            // The node that gets physically unlinked always has at most one child.
            IntervalTreeNode<TK, TV> replacementNode;

            if (LeftOf(nodeToDelete) == null || RightOf(nodeToDelete) == null)
            {
                replacementNode = nodeToDelete;
            }
            else
            {
                replacementNode = PredecessorOf(nodeToDelete);
            }

            IntervalTreeNode<TK, TV> tmp = LeftOf(replacementNode) ?? RightOf(replacementNode);

            if (tmp == null && ParentOf(replacementNode) != null && ColorOf(replacementNode) == Black)
            {
                // Removing a black leaf shortens the black height of its path. There is no child
                // to rebalance from afterwards, so fix the tree up while the leaf is still linked.
                RestoreBalanceAfterRemoval(replacementNode);
            }

            // Read the parent only after the fix-up above, which may have rotated the tree.
            IntervalTreeNode<TK, TV> parent = ParentOf(replacementNode);

            if (tmp != null)
            {
                tmp.Parent = parent;
            }

            if (parent == null)
            {
                _root = tmp;
            }
            else if (replacementNode == LeftOf(parent))
            {
                parent.Left = tmp;
            }
            else
            {
                parent.Right = tmp;
            }

            if (replacementNode != nodeToDelete)
            {
                nodeToDelete.Start = replacementNode.Start;
                nodeToDelete.Value = replacementNode.Value;
                nodeToDelete.End = replacementNode.End;
            }

            if (parent != null)
            {
                // Recalculates max for every ancestor of the unlinked node, which includes
                // nodeToDelete when its contents were replaced above.
                PropagateFull(parent);
            }

            if (tmp != null && ColorOf(replacementNode) == Black)
            {
                RestoreBalanceAfterRemoval(tmp);
            }

            return 1;
        }

        /// <summary>
        /// Returns the node with the largest key where <paramref name="node"/> is considered the root node.
        /// </summary>
        /// <param name="node">Root Node</param>
        /// <returns>Node with the maximum key in the tree of <paramref name="node"/></returns>
        private static IntervalTreeNode<TK, TV> Maximum(IntervalTreeNode<TK, TV> node)
        {
            IntervalTreeNode<TK, TV> tmp = node;
            while (tmp.Right != null)
            {
                tmp = tmp.Right;
            }

            return tmp;
        }

        /// <summary>
        /// Finds the node whose key is immediately less than <paramref name="node"/>.
        /// </summary>
        /// <param name="node">Node to find the predecessor of</param>
        /// <returns>Predecessor of <paramref name="node"/>, or null if there is none</returns>
        private static IntervalTreeNode<TK, TV> PredecessorOf(IntervalTreeNode<TK, TV> node)
        {
            if (node.Left != null)
            {
                return Maximum(node.Left);
            }

            // No left subtree: climb until we leave a right subtree.
            IntervalTreeNode<TK, TV> parent = node.Parent;
            while (parent != null && node == parent.Left)
            {
                node = parent;
                parent = parent.Parent;
            }
            return parent;
        }

        #endregion

        #region Private Methods (RBL)

        /// <summary>
        /// Restores the red-black properties after a black node has been (or is about to be) removed.
        /// </summary>
        /// <param name="balanceNode">Node on the path that is short of one black node</param>
        private void RestoreBalanceAfterRemoval(IntervalTreeNode<TK, TV> balanceNode)
        {
            IntervalTreeNode<TK, TV> ptr = balanceNode;

            while (ptr != _root && ColorOf(ptr) == Black)
            {
                if (ptr == LeftOf(ParentOf(ptr)))
                {
                    IntervalTreeNode<TK, TV> sibling = RightOf(ParentOf(ptr));

                    if (ColorOf(sibling) == Red)
                    {
                        SetColor(sibling, Black);
                        SetColor(ParentOf(ptr), Red);
                        RotateLeft(ParentOf(ptr));
                        sibling = RightOf(ParentOf(ptr));
                    }
                    if (ColorOf(LeftOf(sibling)) == Black && ColorOf(RightOf(sibling)) == Black)
                    {
                        SetColor(sibling, Red);
                        ptr = ParentOf(ptr);
                    }
                    else
                    {
                        if (ColorOf(RightOf(sibling)) == Black)
                        {
                            SetColor(LeftOf(sibling), Black);
                            SetColor(sibling, Red);
                            RotateRight(sibling);
                            sibling = RightOf(ParentOf(ptr));
                        }
                        SetColor(sibling, ColorOf(ParentOf(ptr)));
                        SetColor(ParentOf(ptr), Black);
                        SetColor(RightOf(sibling), Black);
                        RotateLeft(ParentOf(ptr));
                        ptr = _root;
                    }
                }
                else
                {
                    IntervalTreeNode<TK, TV> sibling = LeftOf(ParentOf(ptr));

                    if (ColorOf(sibling) == Red)
                    {
                        SetColor(sibling, Black);
                        SetColor(ParentOf(ptr), Red);
                        RotateRight(ParentOf(ptr));
                        sibling = LeftOf(ParentOf(ptr));
                    }
                    if (ColorOf(RightOf(sibling)) == Black && ColorOf(LeftOf(sibling)) == Black)
                    {
                        SetColor(sibling, Red);
                        ptr = ParentOf(ptr);
                    }
                    else
                    {
                        if (ColorOf(LeftOf(sibling)) == Black)
                        {
                            SetColor(RightOf(sibling), Black);
                            SetColor(sibling, Red);
                            RotateLeft(sibling);
                            sibling = LeftOf(ParentOf(ptr));
                        }
                        SetColor(sibling, ColorOf(ParentOf(ptr)));
                        SetColor(ParentOf(ptr), Black);
                        SetColor(LeftOf(sibling), Black);
                        RotateRight(ParentOf(ptr));
                        ptr = _root;
                    }
                }
            }
            SetColor(ptr, Black);
        }

        /// <summary>
        /// Restores the red-black properties after <paramref name="balanceNode"/> has been linked into the tree.
        /// The node is coloured red first, so this must only be called for newly inserted nodes.
        /// </summary>
        /// <param name="balanceNode">Newly inserted node</param>
        private void RestoreBalanceAfterInsertion(IntervalTreeNode<TK, TV> balanceNode)
        {
            SetColor(balanceNode, Red);
            while (balanceNode != null && balanceNode != _root && ColorOf(ParentOf(balanceNode)) == Red)
            {
                if (ParentOf(balanceNode) == LeftOf(ParentOf(ParentOf(balanceNode))))
                {
                    IntervalTreeNode<TK, TV> sibling = RightOf(ParentOf(ParentOf(balanceNode)));

                    if (ColorOf(sibling) == Red)
                    {
                        SetColor(ParentOf(balanceNode), Black);
                        SetColor(sibling, Black);
                        SetColor(ParentOf(ParentOf(balanceNode)), Red);
                        balanceNode = ParentOf(ParentOf(balanceNode));
                    }
                    else
                    {
                        if (balanceNode == RightOf(ParentOf(balanceNode)))
                        {
                            balanceNode = ParentOf(balanceNode);
                            RotateLeft(balanceNode);
                        }
                        SetColor(ParentOf(balanceNode), Black);
                        SetColor(ParentOf(ParentOf(balanceNode)), Red);
                        RotateRight(ParentOf(ParentOf(balanceNode)));
                    }
                }
                else
                {
                    IntervalTreeNode<TK, TV> sibling = LeftOf(ParentOf(ParentOf(balanceNode)));

                    if (ColorOf(sibling) == Red)
                    {
                        SetColor(ParentOf(balanceNode), Black);
                        SetColor(sibling, Black);
                        SetColor(ParentOf(ParentOf(balanceNode)), Red);
                        balanceNode = ParentOf(ParentOf(balanceNode));
                    }
                    else
                    {
                        if (balanceNode == LeftOf(ParentOf(balanceNode)))
                        {
                            balanceNode = ParentOf(balanceNode);
                            RotateRight(balanceNode);
                        }
                        SetColor(ParentOf(balanceNode), Black);
                        SetColor(ParentOf(ParentOf(balanceNode)), Red);
                        RotateLeft(ParentOf(ParentOf(balanceNode)));
                    }
                }
            }
            SetColor(_root, Black);
        }

        /// <summary>
        /// Rotates the subtree rooted at <paramref name="node"/> to the left: its right child takes its place
        /// and <paramref name="node"/> becomes that child's left child. Max values are recalculated afterwards.
        /// Does nothing if <paramref name="node"/> is null.
        /// </summary>
        /// <param name="node">Node to rotate around. Its right child must not be null</param>
        private void RotateLeft(IntervalTreeNode<TK, TV> node)
        {
            if (node != null)
            {
                IntervalTreeNode<TK, TV> right = RightOf(node);
                node.Right = LeftOf(right);
                if (node.Right != null)
                {
                    node.Right.Parent = node;
                }
                IntervalTreeNode<TK, TV> nodeParent = ParentOf(node);
                right.Parent = nodeParent;
                if (nodeParent == null)
                {
                    _root = right;
                }
                else if (node == LeftOf(nodeParent))
                {
                    nodeParent.Left = right;
                }
                else
                {
                    nodeParent.Right = right;
                }
                right.Left = node;
                node.Parent = right;

                // node is now the lower of the two rotated nodes, so recalculating from it
                // upwards also refreshes its new parent.
                PropagateFull(node);
            }
        }

        /// <summary>
        /// Rotates the subtree rooted at <paramref name="node"/> to the right: its left child takes its place
        /// and <paramref name="node"/> becomes that child's right child. Max values are recalculated afterwards.
        /// Does nothing if <paramref name="node"/> is null.
        /// </summary>
        /// <param name="node">Node to rotate around. Its left child must not be null</param>
        private void RotateRight(IntervalTreeNode<TK, TV> node)
        {
            if (node != null)
            {
                IntervalTreeNode<TK, TV> left = LeftOf(node);
                node.Left = RightOf(left);
                if (node.Left != null)
                {
                    node.Left.Parent = node;
                }
                IntervalTreeNode<TK, TV> nodeParent = ParentOf(node);
                left.Parent = nodeParent;
                if (nodeParent == null)
                {
                    _root = left;
                }
                else if (node == RightOf(nodeParent))
                {
                    nodeParent.Right = left;
                }
                else
                {
                    nodeParent.Left = left;
                }
                left.Right = node;
                node.Parent = left;

                // node is now the lower of the two rotated nodes, so recalculating from it
                // upwards also refreshes its new parent.
                PropagateFull(node);
            }
        }

        #endregion

        #region Safety-Methods

        // These methods save memory by allowing us to forego sentinel nil nodes, as well as serve as protection against NullReferenceExceptions.

        /// <summary>
        /// Returns the color of <paramref name="node"/>, or Black if it is null.
        /// </summary>
        /// <param name="node">Node</param>
        /// <returns>The boolean color of <paramref name="node"/>, or black if null</returns>
        private static bool ColorOf(IntervalTreeNode<TK, TV> node)
        {
            return node == null || node.Color;
        }

        /// <summary>
        /// Sets the color of <paramref name="node"/> node to <paramref name="color"/>.
        /// <br></br>
        /// This method does nothing if <paramref name="node"/> is null.
        /// </summary>
        /// <param name="node">Node to set the color of</param>
        /// <param name="color">Color (Boolean)</param>
        private static void SetColor(IntervalTreeNode<TK, TV> node, bool color)
        {
            if (node != null)
            {
                node.Color = color;
            }
        }

        /// <summary>
        /// This method returns the left node of <paramref name="node"/>, or null if <paramref name="node"/> is null.
        /// </summary>
        /// <param name="node">Node to retrieve the left child from</param>
        /// <returns>Left child of <paramref name="node"/></returns>
        private static IntervalTreeNode<TK, TV> LeftOf(IntervalTreeNode<TK, TV> node)
        {
            return node?.Left;
        }

        /// <summary>
        /// This method returns the right node of <paramref name="node"/>, or null if <paramref name="node"/> is null.
        /// </summary>
        /// <param name="node">Node to retrieve the right child from</param>
        /// <returns>Right child of <paramref name="node"/></returns>
        private static IntervalTreeNode<TK, TV> RightOf(IntervalTreeNode<TK, TV> node)
        {
            return node?.Right;
        }

        /// <summary>
        /// Returns the parent node of <paramref name="node"/>, or null if <paramref name="node"/> is null.
        /// </summary>
        /// <param name="node">Node to retrieve the parent from</param>
        /// <returns>Parent of <paramref name="node"/></returns>
        private static IntervalTreeNode<TK, TV> ParentOf(IntervalTreeNode<TK, TV> node)
        {
            return node?.Parent;
        }

        #endregion

        /// <summary>
        /// Checks whether an interval whose start key is <paramref name="key"/> is on the tree.
        /// </summary>
        /// <param name="key">Start key to search for</param>
        /// <exception cref="ArgumentNullException"><paramref name="key"/> is null</exception>
        /// <returns>True if the key is on the tree, false otherwise</returns>
        public bool ContainsKey(TK key)
        {
            return GetNode(key) != null;
        }

        /// <summary>
        /// Removes all intervals from the tree.
        /// </summary>
        public void Clear()
        {
            _root = null;
            _count = 0;
        }
    }

    /// <summary>
    /// Represents a node in the IntervalTree which contains start and end keys of type <typeparamref name="TK"/>, and a value of generic type <typeparamref name="TV"/>.
    /// </summary>
    /// <typeparam name="TK">Key type of the node</typeparam>
    /// <typeparam name="TV">Value type of the node</typeparam>
    class IntervalTreeNode<TK, TV>
    {
        /// <summary>
        /// Red-black color of the node: true is Black, false is Red.
        /// </summary>
        public bool Color = true;
        public IntervalTreeNode<TK, TV> Left = null;
        public IntervalTreeNode<TK, TV> Right = null;
        public IntervalTreeNode<TK, TV> Parent = null;

        /// <summary>
        /// The start of the range.
        /// </summary>
        public TK Start;

        /// <summary>
        /// The end of the range.
        /// </summary>
        public TK End;

        /// <summary>
        /// The maximum end value of this node and all its children.
        /// </summary>
        public TK Max;

        /// <summary>
        /// Value stored on this node.
        /// </summary>
        public TV Value;

        /// <summary>
        /// Creates a childless node for the given interval. Its max value starts out as <paramref name="end"/>.
        /// </summary>
        /// <param name="start">Start of the range</param>
        /// <param name="end">End of the range</param>
        /// <param name="value">Value stored on the node</param>
        /// <param name="parent">Parent of the node, or null for a root node</param>
        public IntervalTreeNode(TK start, TK end, TV value, IntervalTreeNode<TK, TV> parent)
        {
            Start = start;
            End = end;
            Max = end;
            Value = value;
            Parent = parent;
        }
    }
}