1#ifndef SEAD_PTR_ARRAY_H_
2#define SEAD_PTR_ARRAY_H_
3
#include <algorithm>
#include <type_traits>
#include <basis/seadRawPrint.h>
#include <basis/seadTypes.h>
#include <prim/seadMemUtil.h>
#include <random/seadRandom.h>
9
10namespace sead
11{
12class Heap;
13class Random;
14
15class PtrArrayImpl
16{
17public:
18 PtrArrayImpl() = default;
19 PtrArrayImpl(s32 ptrNumMax, void* buf) { setBuffer(ptrNumMax, buf); }
20
21 void setBuffer(s32 ptrNumMax, void* buf);
22 void allocBuffer(s32 ptrNumMax, Heap* heap, s32 alignment = sizeof(void*));
23 bool tryAllocBuffer(s32 ptrNumMax, Heap* heap, s32 alignment = sizeof(void*));
24 void freeBuffer();
25 bool isBufferReady() const { return mPtrs != nullptr; }
26
27 bool isEmpty() const { return mPtrNum == 0; }
28 bool isFull() const { return mPtrNum >= mPtrNumMax; }
29
30 s32 size() const { return mPtrNum; }
31 s32 capacity() const { return mPtrNumMax; }
32
33 void erase(s32 position) { erase(position, count: 1); }
34 void erase(s32 position, s32 count);
35 void clear() { mPtrNum = 0; }
36
37 // TODO
38 void resize(s32 size);
39 // TODO
40 void unsafeResize(s32 size);
41
42 void swap(s32 pos1, s32 pos2)
43 {
44 auto* ptr = mPtrs[pos1];
45 mPtrs[pos1] = mPtrs[pos2];
46 mPtrs[pos2] = ptr;
47 }
48 void reverse();
49 void shuffle()
50 {
51 Random random;
52 shuffle(random: &random);
53 }
54 void shuffle(Random* random);
55
56protected:
57 using CompareCallbackImpl = int (*)(const void* a, const void* b);
58
59 void* at(s32 idx) const
60 {
61 if (u32(mPtrNum) <= u32(idx))
62 {
63 SEAD_ASSERT_MSG(false, "index exceeded [%d/%d]", idx, mPtrNum);
64 return nullptr;
65 }
66 return mPtrs[idx];
67 }
68
69 void* unsafeAt(s32 idx) const { return mPtrs[idx]; }
70
71 // XXX: should this use at()?
72 void* front() const { return mPtrs[0]; }
73 void* back() const { return mPtrs[mPtrNum - 1]; }
74
75 void pushBack(void* ptr)
76 {
77 if (isFull())
78 {
79 SEAD_ASSERT_MSG(false, "list is full.");
80 return;
81 }
82 // Simplest insert case, so this is implemented directly without using insert().
83 mPtrs[mPtrNum] = ptr;
84 ++mPtrNum;
85 }
86
87 void pushFront(void* ptr) { insert(idx: 0, ptr); }
88
89 void* popBack() { return isEmpty() ? nullptr : mPtrs[--mPtrNum]; }
90
91 void* popFront()
92 {
93 if (isEmpty())
94 return nullptr;
95
96 void* result = mPtrs[0];
97 erase(position: 0);
98 return result;
99 }
100
101 void replace(s32 idx, void* ptr) { mPtrs[idx] = ptr; }
102
103 void* find(const void* ptr, CompareCallbackImpl cmp) const
104 {
105 for (s32 i = 0; i < mPtrNum; ++i)
106 {
107 if (cmp(mPtrs[i], ptr) == 0)
108 return mPtrs[i];
109 }
110 return nullptr;
111 }
112
113 s32 search(const void* ptr, CompareCallbackImpl cmp) const
114 {
115 for (s32 i = 0; i < mPtrNum; ++i)
116 {
117 if (cmp(mPtrs[i], ptr) == 0)
118 return i;
119 }
120 return -1;
121 }
122
123 bool equal(const PtrArrayImpl& other, CompareCallbackImpl cmp) const
124 {
125 if (mPtrNum != other.mPtrNum)
126 return false;
127
128 for (s32 i = 0; i < mPtrNum; ++i)
129 {
130 if (cmp(mPtrs[i], other.mPtrs[i]) != 0)
131 return false;
132 }
133 return true;
134 }
135
136 s32 indexOf(const void* ptr) const
137 {
138 for (s32 i = 0; i < mPtrNum; ++i)
139 {
140 if (mPtrs[i] == ptr)
141 return i;
142 }
143 return -1;
144 }
145
146 void createVacancy(s32 pos, s32 count)
147 {
148 if (mPtrNum <= pos)
149 return;
150
151 MemUtil::copyOverlap(dest: mPtrs + pos + count, src: mPtrs + pos,
152 size: s32(sizeof(void*)) * (mPtrNum - pos));
153 }
154
155 void insert(s32 idx, void* ptr);
156 void insertArray(s32 idx, void* array, s32 array_length, s32 elem_size);
157 bool checkInsert(s32 idx, s32 num);
158
159 template <typename T, typename Compare>
160 void sort_(Compare cmp)
161 {
162 // Note: Nintendo did not use <algorithm>
163 std::sort(mPtrs, mPtrs + size(), [&](const void* a, const void* b) {
164 return cmp(static_cast<const T*>(a), static_cast<const T*>(b)) < 0;
165 });
166 }
167
168 template <typename T, typename Compare>
169 void heapSort_(Compare cmp)
170 {
171 // Note: Nintendo did not use <algorithm>
172 const auto less_cmp = [&](const void* a, const void* b) {
173 return cmp(static_cast<const T*>(a), static_cast<const T*>(b)) < 0;
174 };
175 std::make_heap(mPtrs, mPtrs + size(), less_cmp);
176 std::sort_heap(mPtrs, mPtrs + size(), less_cmp);
177 }
178
179 void heapSort(CompareCallbackImpl cmp);
180
181 s32 compare(const PtrArrayImpl& other, CompareCallbackImpl cmp) const;
182 void uniq(CompareCallbackImpl cmp);
183
184 s32 binarySearch(const void* ptr, CompareCallbackImpl cmp) const
185 {
186 if (mPtrNum == 0)
187 return -1;
188
189 s32 a = 0;
190 s32 b = mPtrNum - 1;
191 while (a < b)
192 {
193 const s32 m = (a + b) / 2;
194 const s32 c = cmp(mPtrs[m], ptr);
195 if (c == 0)
196 return m;
197 if (c < 0)
198 a = m + 1;
199 else
200 b = m;
201 }
202
203 if (cmp(mPtrs[a], ptr) == 0)
204 return a;
205
206 return -1;
207 }
208
209 s32 mPtrNum = 0;
210 s32 mPtrNumMax = 0;
211 void** mPtrs = nullptr;
212};
213
214template <typename T>
215class PtrArray : public PtrArrayImpl
216{
217public:
218 PtrArray() = default;
219 PtrArray(s32 ptrNumMax, T** buf) : PtrArrayImpl(ptrNumMax, buf) {}
220
221 T* at(s32 pos) const { return static_cast<T*>(PtrArrayImpl::at(idx: pos)); }
222 T* unsafeAt(s32 pos) const { return static_cast<T*>(PtrArrayImpl::unsafeAt(idx: pos)); }
223 T* operator()(s32 pos) const { return unsafeAt(pos); }
224 T* operator[](s32 pos) const { return at(pos); }
225
226 // XXX: Does this use at()?
227 T* front() const { return at(pos: 0); }
228 T* back() const { return at(pos: mPtrNum - 1); }
229
230 void pushBack(T* ptr) { PtrArrayImpl::pushBack(ptr: constCast(ptr)); }
231 void pushFront(T* ptr) { PtrArrayImpl::pushFront(ptr: constCast(ptr)); }
232
233 T* popBack() { return static_cast<T*>(PtrArrayImpl::popBack()); }
234 T* popFront() { return static_cast<T*>(PtrArrayImpl::popFront()); }
235
236 void insert(s32 pos, T* ptr) { PtrArrayImpl::insert(idx: pos, ptr: constCast(ptr)); }
237 void insert(s32 pos, T* array, s32 count)
238 {
239 // XXX: is this right?
240 PtrArrayImpl::insertArray(idx: pos, array: constCast(ptr: array), array_length: count, elem_size: sizeof(T));
241 }
242 void replace(s32 pos, T* ptr) { PtrArrayImpl::replace(idx: pos, ptr: constCast(ptr)); }
243
244 s32 indexOf(const T* ptr) const { return PtrArrayImpl::indexOf(ptr); }
245
246 using CompareCallback = s32 (*)(const T*, const T*);
247
248 void sort() { sort(compareT); }
249 void sort(CompareCallback cmp) { PtrArrayImpl::sort_<T>(cmp); }
250 void heapSort() { heapSort(compareT); }
251 void heapSort(CompareCallback cmp) { PtrArrayImpl::heapSort_<T>(cmp); }
252
253 bool equal(const PtrArray& other, CompareCallback cmp) const
254 {
255 return PtrArrayImpl::equal(other, cmp);
256 }
257
258 s32 compare(const PtrArray& other, CompareCallback cmp) const
259 {
260 return PtrArrayImpl::compare(other, cmp);
261 }
262
263 T* find(const T* ptr) const
264 {
265 return PtrArrayImpl::find(ptr,
266 cmp: [](const void* a, const void* b) { return a == b ? 0 : -1; });
267 }
268 T* find(const T* ptr, CompareCallback cmp) const { return PtrArrayImpl::find(ptr, cmp); }
269 s32 search(const T* ptr) const
270 {
271 return PtrArrayImpl::search(ptr,
272 cmp: [](const void* a, const void* b) { return a == b ? 0 : -1; });
273 }
274 s32 search(const T* ptr, CompareCallback cmp) const { return PtrArrayImpl::search(ptr, cmp); }
275 s32 binarySearch(const T* ptr) const { return PtrArrayImpl::binarySearch(ptr, cmp: compareT); }
276 s32 binarySearch(const T* ptr, CompareCallback cmp) const
277 {
278 return PtrArrayImpl::binarySearch(ptr, cmp);
279 }
280
281 bool operator==(const PtrArray& other) const { return equal(other, cmp: compareT); }
282 bool operator!=(const PtrArray& other) const { return !(*this == other); }
283 bool operator<(const PtrArray& other) const { return compare(other) < 0; }
284 bool operator<=(const PtrArray& other) const { return compare(other) <= 0; }
285 bool operator>(const PtrArray& other) const { return compare(other) > 0; }
286 bool operator>=(const PtrArray& other) const { return compare(other) >= 0; }
287
288 void uniq() { PtrArrayImpl::uniq(cmp: compareT); }
289 void uniq(CompareCallback cmp) { PtrArrayImpl::uniq(cmp); }
290
291 class iterator
292 {
293 public:
294 iterator(T* const* pptr) : mPPtr{pptr} {}
295 bool operator==(const iterator& other) const { return mPPtr == other.mPPtr; }
296 bool operator!=(const iterator& other) const { return !(*this == other); }
297 iterator& operator++()
298 {
299 ++mPPtr;
300 return *this;
301 }
302 T& operator*() const { return **mPPtr; }
303 T* operator->() const { return *mPPtr; }
304
305 private:
306 T* const* mPPtr;
307 };
308
309 iterator begin() const { return iterator(data()); }
310 iterator end() const { return iterator(data() + mPtrNum); }
311
312 class constIterator
313 {
314 public:
315 constIterator(const T* const* pptr) : mPPtr{pptr} {}
316 bool operator==(const constIterator& other) const { return mPPtr == other.mPPtr; }
317 bool operator!=(const constIterator& other) const { return !(*this == other); }
318 constIterator& operator++()
319 {
320 ++mPPtr;
321 return *this;
322 }
323 const T& operator*() const { return **mPPtr; }
324 const T* operator->() const { return *mPPtr; }
325
326 private:
327 const T* const* mPPtr;
328 };
329
330 constIterator constBegin() const { return constIterator(data()); }
331 constIterator constEnd() const { return constIterator(data() + mPtrNum); }
332
333 T** data() const { return reinterpret_cast<T**>(mPtrs); }
334 T** dataBegin() const { return data(); }
335 T** dataEnd() const { return data() + mPtrNum; }
336
337protected:
338 static void* constCast(const T* ptr)
339 {
340 // Unfortunately, we need to cast away const because several PtrArrayImpl functions
341 // only take void* even though the pointed-to object isn't actually modified.
342 return static_cast<void*>(const_cast<std::remove_const_t<T>*>(ptr));
343 }
344
345 static int compareT(const void* a_, const void* b_)
346 {
347 const T* a = static_cast<const T*>(a_);
348 const T* b = static_cast<const T*>(b_);
349 if (*a < *b)
350 return -1;
351 if (*b < *a)
352 return 1;
353 return 0;
354 }
355};
356
357template <typename T, s32 N>
358class FixedPtrArray : public PtrArray<T>
359{
360public:
361 FixedPtrArray() : PtrArray<T>(N, mWork) {}
362
363 // These do not make sense for a *fixed* array.
364 void setBuffer(s32 ptrNumMax, void* buf) = delete;
365 void allocBuffer(s32 ptrNumMax, Heap* heap, s32 alignment = sizeof(void*)) = delete;
366 bool tryAllocBuffer(s32 ptrNumMax, Heap* heap, s32 alignment = sizeof(void*)) = delete;
367 void freeBuffer() = delete;
368
369private:
370 // Nintendo uses an untyped u8[N*sizeof(void*)] buffer. That is undefined behavior,
371 // so we will not do that.
372 T* mWork[N];
373};
374
375} // namespace sead
376
377#endif // SEAD_PTR_ARRAY_H_
378