1#pragma once
2
3#include "basis/seadRawPrint.h"
4#include "container/seadFreeList.h"
5#include "prim/seadBitUtil.h"
6#include "prim/seadDelegate.h"
7#include "prim/seadSafeString.h"
8
9namespace sead
10{
template <typename Key>
class TreeMapNode;

/// Sorted associative container, implemented as a left-leaning red-black (LLRB) tree.
/// For an explanation of the algorithm, see https://www.cs.princeton.edu/~rs/talks/LLRB/LLRB.pdf
///
/// The tree only links nodes together; it never allocates them. Nodes that are removed or
/// replaced are handed back through their erase_() hook.
template <typename Key>
class TreeMapImpl
{
public:
    using Node = TreeMapNode<Key>;

    /// Links `node` into the tree. A node already present with an equal key is replaced.
    void insert(Node* node);
    /// Removes the node whose key equals `key`.
    void erase(const Key& key);
    /// Forgets every node (without calling their erase_() hooks).
    void clear();

    /// Returns the node whose key equals `key`, or nullptr if there is none.
    Node* find(const Key& key) const { return find(mRoot, key); }

    /// Calls callable(Node*) for every node, in ascending key order.
    template <typename Callable>
    void forEach(const Callable& callable) const
    {
        if (mRoot != nullptr)
            forEach(mRoot, callable);
    }

    /// Returns the node with the smallest key (nullptr for an empty tree), recording parent
    /// links on the way down so that nextNode() can later climb back up.
    Node* startIterating() const { return mRoot ? startIterating(mRoot) : nullptr; }

    /// Returns the in-order successor of `node`, or nullptr if `node` is the last node (or null).
    /// @warning `node` must have been obtained from startIterating()/nextNode(), and the tree must
    /// not have been modified since: the upward walk relies on parent links recorded while
    /// iterating.
    Node* nextNode(Node* node) const
    {
        if (node == nullptr)
            return nullptr;

        // With a right subtree, the successor is that subtree's left-most node.
        if (Node* const right = node->mRight)
        {
            right->setParent(node);
            return startIterating(right);
        }

        // Otherwise climb until we arrive at an ancestor from its left side; that ancestor is the
        // successor. Running out of ancestors means `node` was the maximum.
        Node* child = node;
        Node* ancestor = child->getParent();
        while (ancestor != nullptr && ancestor->mLeft != child)
        {
            child = ancestor;
            ancestor = child->getParent();
        }
        return ancestor;
    }

protected:
    /// Returns the left-most node of the (non-null) subtree rooted at `node`, recording each
    /// visited child's parent along the way.
    static Node* startIterating(Node* node)
    {
        for (Node* left = node->mLeft; left != nullptr; left = node->mLeft)
        {
            left->setParent(node);
            node = left;
        }
        return node;
    }

    Node* insert(Node* root, Node* node);
    Node* erase(Node* root, const Key& key);
    Node* find(Node* root, const Key& key) const;

    static inline Node* rotateLeft(Node* node);
    static inline Node* rotateRight(Node* node);
    static inline Node* moveRedLeft(Node* node);
    static inline Node* moveRedRight(Node* node);
    static Node* findMin(Node* node);
    static Node* eraseMin(Node* node);
    static inline Node* fixUp(Node* node);
    /// Null links count as black.
    static bool isRed(const Node* node) { return node != nullptr && node->isRed(); }
    static inline void flipColors(Node* node);

    template <typename Callable>
    static void forEach(Node* start, const Callable& callable);

    Node* mRoot = nullptr;
};
97
98/// Requires Key to have a compare() member function, which returns -1 if lhs < rhs, 0 if lhs = rhs
99/// and 1 if lhs > rhs.
100template <typename Key>
101class TreeMapNode
102{
103public:
104 TreeMapNode()
105 {
106 mLeft = mRight = nullptr;
107 mColorAndPtr = 0;
108 }
109
110 virtual ~TreeMapNode() = default;
111 virtual void erase_() = 0;
112
113 const Key& key() const { return mKey; }
114
115protected:
116 friend class TreeMapImpl<Key>;
117
118 enum class Color
119 {
120 Red = 0,
121 Black = 1,
122 };
123
124 void flipColor() { BitUtil::bitCastWrite(value: mColorAndPtr ^ 1u, ptr: &mColorAndPtr); }
125 void setColor(Color color) { mColorAndPtr = uintptr_t(color); }
126
127 void setParent(TreeMapNode* parent) { mColorAndPtr = (mColorAndPtr & 1) | uintptr_t(parent); }
128 /// @warning Only valid if setParent has been called!
129 TreeMapNode* getParent() const { return reinterpret_cast<TreeMapNode*>(mColorAndPtr & ~1); }
130
131 bool isRed() const { return (mColorAndPtr & 1u) == bool(Color::Red); }
132
133 TreeMapNode* mLeft;
134 TreeMapNode* mRight;
135 uintptr_t mColorAndPtr;
136 Key mKey;
137};
138
139/// Requires Key to have operator< defined
140/// This can be specialized, but all specializations must define `compare` and `key` as follows.
141template <typename Key>
142struct TreeMapKeyImpl
143{
144 TreeMapKeyImpl() = default;
145 TreeMapKeyImpl(const Key& key_) : key(key_) {}
146 TreeMapKeyImpl& operator=(const Key& key_)
147 {
148 key = key_;
149 return *this;
150 }
151
152 /// Returns -1 if mKey < rhs, 0 if mKey = rhs and 1 if mKey > rhs.
153 s32 compare(const TreeMapKeyImpl& rhs) const
154 {
155 if (key < rhs.key)
156 return -1;
157 if (rhs.key < key)
158 return 1;
159 return 0;
160 }
161
162 Key key;
163};
164
165/// Sorted associative container.
166/// This is essentially std::map<Key, Value>
167template <typename Key, typename Value>
168class TreeMap : public TreeMapImpl<TreeMapKeyImpl<Key>>
169{
170public:
171 using MapImpl = TreeMapImpl<TreeMapKeyImpl<Key>>;
172 class Node : public MapImpl::Node
173 {
174 public:
175 Node(TreeMap* map, const Key& key, const Value& value) : mValue(value), mMap(map)
176 {
177 this->mKey = key;
178 }
179
180 void erase_() override;
181
182 Value& value() { return mValue; }
183 const Value& value() const { return mValue; }
184
185 private:
186 friend class TreeMap;
187
188 Value mValue;
189 TreeMap* mMap;
190 };
191
192 void allocBuffer(s32 node_max, Heap* heap, s32 alignment = sizeof(void*));
193 void setBuffer(s32 node_max, void* buffer);
194 void freeBuffer();
195
196 Value* insert(const Key& key, const Value& value);
197 void clear();
198
199 Node* find(const Key& key) const;
200
201 // Callable must have the signature Key&, Value&
202 template <typename Callable>
203 void forEach(const Callable& delegate) const;
204
205 Node* startIterating() const { return static_cast<Node*>(MapImpl::startIterating()); }
206 Node* nextNode(Node* node) const { return static_cast<Node*>(MapImpl::nextNode(node)); }
207
208private:
209 void eraseNodeForClear_(typename MapImpl::Node* node);
210
211 FreeList mFreeList;
212 s32 mSize = 0;
213 s32 mCapacity = 0;
214};
215
216template <typename Key, typename Value, int N>
217class FixedTreeMap : public TreeMap<Key, Value>
218{
219public:
220 FixedTreeMap() { TreeMap<Key, Value>::setBuffer(N, &mWork); }
221
222 void setBuffer(s32 ptrNumMax, void* buf) = delete;
223 void allocBuffer(s32 ptrNumMax, Heap* heap, s32 alignment = sizeof(void*)) = delete;
224 bool tryAllocBuffer(s32 ptrNumMax, Heap* heap, s32 alignment = sizeof(void*)) = delete;
225 void freeBuffer() = delete;
226
227private:
228 using NodeType = typename TreeMap<Key, Value>::Node;
229 static_assert(sizeof(NodeType) >= sizeof(void*));
230
231 alignas(std::max(a: alignof(NodeType), b: alignof(NodeType*))) u8 mWork[N * sizeof(NodeType)];
232};
233
234template <typename Key, typename Node>
235class IntrusiveTreeMap : public TreeMapImpl<Key>
236{
237public:
238 using MapImpl = TreeMapImpl<Key>;
239
240 Node* find(const Key& key) const { return static_cast<Node*>(MapImpl::find(key)); }
241
242 // Callable must have the signature Node*
243 template <typename Callable>
244 void forEach(const Callable& delegate) const
245 {
246 MapImpl::forEach([delegate](auto* base_node) {
247 auto* node = static_cast<Node*>(base_node);
248 delegate(node);
249 });
250 }
251
252 Node* startIterating() const { return static_cast<Node*>(MapImpl::startIterating()); }
253 Node* nextNode(Node* node) const { return static_cast<Node*>(MapImpl::nextNode(node)); }
254};
255
256template <typename Key>
257inline void TreeMapImpl<Key>::insert(Node* node)
258{
259 mRoot = insert(mRoot, node);
260 mRoot->setColor(Node::Color::Black);
261}
262
263template <typename Key>
264inline TreeMapNode<Key>* TreeMapImpl<Key>::insert(Node* root, Node* node)
265{
266 if (!root)
267 {
268 node->mLeft = node->mRight = nullptr;
269 node->setColor(Node::Color::Red);
270 return node;
271 }
272
273 const s32 cmp = node->key().compare(root->key());
274
275 if (cmp < 0)
276 {
277 root->mLeft = insert(root->mLeft, node);
278 }
279 else if (cmp > 0)
280 {
281 root->mRight = insert(root->mRight, node);
282 }
283 else if (root != node)
284 {
285 node->mRight = root->mRight;
286 node->mLeft = root->mLeft;
287 node->mColorAndPtr = root->mColorAndPtr;
288 root->erase_();
289 root = node;
290 }
291
292 if (isRed(node: root->mRight) && !isRed(node: root->mLeft))
293 root = rotateLeft(node: root);
294
295 if (isRed(node: root->mLeft) && isRed(node: root->mLeft->mLeft))
296 root = rotateRight(node: root);
297
298 if (isRed(node: root->mLeft) && isRed(node: root->mRight))
299 flipColors(node: root);
300
301 return root;
302}
303
304template <typename Key>
305inline void TreeMapImpl<Key>::erase(const Key& key)
306{
307 mRoot = erase(mRoot, key);
308 if (mRoot)
309 mRoot->setColor(Node::Color::Black);
310}
311
312template <typename Key>
313inline TreeMapNode<Key>* TreeMapImpl<Key>::erase(Node* root, const Key& key)
314{
315 if (key.compare(root->key()) < 0)
316 {
317 if (!isRed(node: root->mLeft) && !isRed(node: root->mLeft->mLeft))
318 root = moveRedLeft(node: root);
319 root->mLeft = erase(root->mLeft, key);
320 }
321 else
322 {
323 if (isRed(node: root->mLeft))
324 root = rotateRight(node: root);
325
326 if (key.compare(root->key()) == 0 && !root->mRight)
327 {
328 root->erase_();
329 return nullptr;
330 }
331
332 if (!isRed(node: root->mRight) && !isRed(node: root->mRight->mLeft))
333 root = moveRedRight(node: root);
334
335 if (key.compare(root->key()) == 0)
336 {
337 Node* const min_node = findMin(node: root->mRight);
338
339 Node* target = root->mRight;
340 if (root->mRight)
341 target = find(root->mRight, min_node->key());
342
343 target->mRight = eraseMin(node: root->mRight);
344 target->mLeft = root->mLeft;
345 target->mColorAndPtr = root->mColorAndPtr;
346 root->erase_();
347 root = target;
348 }
349 else
350 {
351 root->mRight = erase(root->mRight, key);
352 }
353 }
354 return fixUp(node: root);
355}
356
357template <typename Key>
358inline void TreeMapImpl<Key>::clear()
359{
360 mRoot = nullptr;
361}
362
363template <typename Key>
364inline TreeMapNode<Key>* TreeMapImpl<Key>::find(Node* root, const Key& key) const
365{
366 Node* node = root;
367 while (node)
368 {
369 const s32 cmp = key.compare(node->key());
370 if (cmp < 0)
371 node = node->mLeft;
372 else if (cmp > 0)
373 node = node->mRight;
374 else
375 return node;
376 }
377
378 return nullptr;
379}
380
381template <typename Key>
382template <typename Callable>
383inline void TreeMapImpl<Key>::forEach(Node* start, const Callable& callable)
384{
385 Node* i = start;
386 do
387 {
388 Node* node = i;
389 if (i->mLeft)
390 forEach(i->mLeft, callable);
391 i = i->mRight;
392 callable(node);
393 } while (i);
394}
395
396template <typename Key>
397inline TreeMapNode<Key>* TreeMapImpl<Key>::rotateLeft(Node* node)
398{
399 TreeMapNode<Key>* j = node->mRight;
400 node->mRight = j->mLeft;
401 j->mLeft = node;
402 j->mColorAndPtr = node->mColorAndPtr;
403 node->setColor(Node::Color::Red);
404 return j;
405}
406
407template <typename Key>
408inline TreeMapNode<Key>* TreeMapImpl<Key>::rotateRight(Node* node)
409{
410 TreeMapNode<Key>* j = node->mLeft;
411 node->mLeft = j->mRight;
412 j->mRight = node;
413 j->mColorAndPtr = node->mColorAndPtr;
414 node->setColor(Node::Color::Red);
415 return j;
416}
417
418// NON_MATCHING: this version matches the LLRB tree implementation and is better optimized;
419// there is a useless store to node->mRight in the original version
420template <typename Key>
421inline TreeMapNode<Key>* TreeMapImpl<Key>::moveRedLeft(Node* node)
422{
423 flipColors(node);
424 if (isRed(node: node->mRight->mLeft))
425 {
426 node->mRight = rotateRight(node: node->mRight);
427 node = rotateLeft(node);
428 flipColors(node);
429 }
430 return node;
431}
432
433template <typename Key>
434inline TreeMapNode<Key>* TreeMapImpl<Key>::moveRedRight(Node* node)
435{
436 flipColors(node);
437 if (isRed(node: node->mLeft->mLeft))
438 {
439 node = rotateRight(node);
440 flipColors(node);
441 }
442 return node;
443}
444
445template <typename Key>
446inline TreeMapNode<Key>* TreeMapImpl<Key>::findMin(Node* node)
447{
448 while (node->mLeft)
449 node = node->mLeft;
450 return node;
451}
452
453// NON_MATCHING: this version matches the LLRB tree implementation and is better optimized
454template <typename Key>
455inline TreeMapNode<Key>* TreeMapImpl<Key>::eraseMin(Node* node)
456{
457 if (!node->mLeft)
458 return nullptr;
459
460 if (!isRed(node: node->mLeft) && !isRed(node: node->mLeft->mLeft))
461 node = moveRedLeft(node);
462
463 node->mLeft = eraseMin(node: node->mLeft);
464#ifdef MATCHING_HACK_NX_CLANG
465 asm("");
466#endif
467 return fixUp(node);
468}
469
470template <typename Key>
471inline TreeMapNode<Key>* TreeMapImpl<Key>::fixUp(Node* node)
472{
473 if (isRed(node: node->mRight))
474 node = rotateLeft(node);
475
476 if (isRed(node: node->mLeft) && isRed(node: node->mLeft->mLeft))
477 node = rotateRight(node);
478
479 if (isRed(node: node->mLeft) && isRed(node: node->mRight))
480 flipColors(node);
481
482 return node;
483}
484
485template <typename Key>
486inline void TreeMapImpl<Key>::flipColors(Node* node)
487{
488 node->flipColor();
489 node->mLeft->flipColor();
490 node->mRight->flipColor();
491}
492
493template <typename Key, typename Value>
494inline void TreeMap<Key, Value>::Node::erase_()
495{
496 TreeMap* const map = mMap;
497 void* const this_ = this;
498 // Note: Nintendo does not call the destructor, which is dangerous...
499 map->mFreeList.free(ptr: this_);
500 --map->mSize;
501}
502
503template <typename Key, typename Value>
504inline void TreeMap<Key, Value>::allocBuffer(s32 node_max, Heap* heap, s32 alignment)
505{
506 s32 node_size = sizeof(Node);
507
508 SEAD_ASSERT(mFreeList.work() == nullptr);
509 if (node_max <= 0)
510 {
511 SEAD_ASSERT_MSG(false, "node_max[%d] must be larger than zero", node_max);
512 AllocFailAssert(heap, node_max * node_size, alignment);
513 }
514
515 void* work = AllocBuffer(size: node_max * node_size, heap, alignment);
516 if (work)
517 setBuffer(node_max, buffer: work);
518}
519
520template <typename Key, typename Value>
521inline void TreeMap<Key, Value>::setBuffer(s32 node_max, void* buffer)
522{
523 mCapacity = node_max;
524 mFreeList.setWork(work: buffer, elem_size: sizeof(Node), num: node_max);
525}
526
527template <typename Key, typename Value>
528inline void TreeMap<Key, Value>::freeBuffer()
529{
530 void* buffer = mFreeList.work();
531 if (!buffer)
532 return;
533
534 ::operator delete[](ptr: buffer);
535 mCapacity = 0;
536 mFreeList.reset();
537}
538
539template <typename Key, typename Value>
540inline Value* TreeMap<Key, Value>::insert(const Key& key, const Value& value)
541{
542 Value* ptr = nullptr;
543
544 if (mSize < mCapacity)
545 {
546 Node* node = new (mFreeList.alloc()) Node(this, key, value);
547 ptr = &node->value();
548 ++mSize;
549 MapImpl::insert(node);
550 }
551 else if (Node* node = find(key))
552 {
553 ptr = &node->value();
554 new (ptr) Value(value);
555 }
556 else
557 {
558 SEAD_ASSERT_MSG(false, "map is full.");
559 }
560
561 return ptr;
562}
563
564template <typename Key, typename Value>
565inline void TreeMap<Key, Value>::clear()
566{
567 Delegate1<TreeMap<Key, Value>, typename MapImpl::Node*> delegate(this,
568 &TreeMap::eraseNodeForClear_);
569 MapImpl::forEach(delegate);
570 mSize = 0;
571 MapImpl::clear();
572}
573
574template <typename Key, typename Value>
575inline typename TreeMap<Key, Value>::Node* TreeMap<Key, Value>::find(const Key& key) const
576{
577 return static_cast<Node*>(MapImpl::find(key));
578}
579
580template <typename Key, typename Value>
581template <typename Callable>
582inline void TreeMap<Key, Value>::forEach(const Callable& delegate) const
583{
584 MapImpl::forEach([&delegate](auto* base_node) {
585 auto* node = static_cast<Node*>(base_node);
586 delegate(node->key(), node->value());
587 });
588}
589
590template <typename Key, typename Value>
591inline void TreeMap<Key, Value>::eraseNodeForClear_(typename MapImpl::Node* node)
592{
593 // Note: Nintendo does not call the destructor, which is dangerous...
594 mFreeList.free(ptr: node);
595}
596} // namespace sead
597