1#ifndef SEAD_BUFFER_H_
2#define SEAD_BUFFER_H_
3
4#include <algorithm>
5#include <type_traits>
6
7#include <basis/seadNew.h>
8#include <basis/seadRawPrint.h>
9#include <basis/seadTypes.h>
10#include <prim/seadPtrUtil.h>
11
12namespace sead
13{
14class Heap;
15
16template <typename T>
17class Buffer
18{
19public:
20 Buffer() : mSize(0), mBuffer(NULL) {}
21 Buffer(s32 size, T* buffer) : mSize(size), mBuffer(buffer) {}
22 template <s32 N>
23 Buffer(T (&array)[N]) : Buffer(N, array)
24 {
25 }
26
27 Buffer(const Buffer& other) { *this = other; }
28
29 Buffer& operator=(const Buffer& other)
30 {
31 mSize = other.mSize;
32 mBuffer = other.mBuffer;
33 return *this;
34 }
35
36 class iterator
37 {
38 public:
39 explicit iterator(T* buffer, s32 index = 0) : mIndex(index), mBuffer(buffer) {}
40 bool operator==(const iterator& rhs) const
41 {
42 return mIndex == rhs.mIndex && mBuffer == rhs.mBuffer;
43 }
44 bool operator!=(const iterator& rhs) const { return !operator==(rhs); }
45 iterator& operator++()
46 {
47 ++mIndex;
48 return *this;
49 }
50 T& operator*() const { return mBuffer[mIndex]; }
51 T* operator->() const { return &mBuffer[mIndex]; }
52 s32 getIndex() const { return mIndex; }
53
54 private:
55 s32 mIndex;
56 T* mBuffer;
57 };
58
59 class constIterator
60 {
61 public:
62 explicit constIterator(const T* buffer, s32 index = 0) : mIndex(index), mBuffer(buffer) {}
63 bool operator==(const constIterator& rhs) const
64 {
65 return mIndex == rhs.mIndex && mBuffer == rhs.mBuffer;
66 }
67 bool operator!=(const constIterator& rhs) const { return !operator==(rhs); }
68 constIterator& operator++()
69 {
70 ++mIndex;
71 return *this;
72 }
73 const T& operator*() const { return mBuffer[mIndex]; }
74 const T* operator->() const { return &mBuffer[mIndex]; }
75 s32 getIndex() const { return mIndex; }
76
77 private:
78 s32 mIndex;
79 const T* mBuffer;
80 };
81
82 iterator begin() { return iterator(mBuffer); }
83 iterator begin(s32 idx)
84 {
85 if (u32(size()) < u32(idx))
86 {
87 SEAD_ASSERT_MSG(false, "range over [0,%d] : %d", size(), idx);
88 return end();
89 }
90 return iterator(mBuffer, idx);
91 }
92
93 constIterator begin() const { return constIterator(mBuffer); }
94 constIterator begin(s32 idx) const
95 {
96 if (u32(size()) < u32(idx))
97 {
98 SEAD_ASSERT_MSG(false, "range over [0,%d] : %d", size(), idx);
99 return end();
100 }
101 return constIterator(mBuffer, idx);
102 }
103
104 iterator end() { return iterator(mBuffer, mSize); }
105 constIterator end() const { return constIterator(mBuffer, mSize); }
106
107 class reverseIterator
108 {
109 public:
110 explicit reverseIterator(T* buffer, s32 index = 0) : mIndex(index), mBuffer(buffer) {}
111 bool operator==(const reverseIterator& rhs) const { return mIndex == rhs.mIndex; }
112 bool operator!=(const reverseIterator& rhs) const { return !operator==(rhs); }
113 reverseIterator& operator++()
114 {
115 --mIndex;
116 return *this;
117 }
118 T& operator*() const { return mBuffer[mIndex]; }
119 T* operator->() const { return &mBuffer[mIndex]; }
120 s32 getIndex() const { return mIndex; }
121
122 private:
123 s32 mIndex;
124 T* mBuffer;
125 };
126
127 reverseIterator rbegin() { return reverseIterator(mBuffer, mSize - 1); }
128 reverseIterator rbegin(s32 index) { return reverseIterator(mBuffer, index); }
129 reverseIterator rend() { return reverseIterator(mBuffer, -1); }
130
131 void allocBuffer(s32 size, s32 alignment)
132 {
133 SEAD_ASSERT(mBuffer == nullptr);
134 if (size > 0)
135 {
136 T* buffer = new (alignment) T[size];
137 if (buffer)
138 {
139 mSize = size;
140 mBuffer = buffer;
141 SEAD_ASSERT_MSG(PtrUtil::isAlignedPow2(mBuffer, abs(alignment)),
142 "don't set alignment for a class with destructor");
143 }
144 }
145 else
146 {
147 SEAD_ASSERT_MSG(false, "size[%d] must be larger than zero", size);
148 }
149 }
150
151 void allocBuffer(s32 size, Heap* heap, s32 alignment = sizeof(void*))
152 {
153 SEAD_ASSERT(mBuffer == nullptr);
154 if (size > 0)
155 {
156 T* buffer = new (heap, alignment) T[size];
157 if (buffer)
158 {
159 mSize = size;
160 mBuffer = buffer;
161 SEAD_ASSERT_MSG(PtrUtil::isAlignedPow2(mBuffer, abs(alignment)),
162 "don't set alignment for a class with destructor");
163 }
164 }
165 else
166 {
167 SEAD_ASSERT_MSG(false, "size[%d] must be larger than zero", size);
168 }
169 }
170
171 bool tryAllocBuffer(s32 size, s32 alignment = sizeof(void*))
172 {
173 SEAD_ASSERT(mBuffer == nullptr);
174 if (size > 0)
175 {
176 T* buffer = new (alignment, std::nothrow) T[size];
177 if (buffer)
178 {
179 mSize = size;
180 mBuffer = buffer;
181 SEAD_ASSERT_MSG(PtrUtil::isAlignedPow2(mBuffer, abs(alignment)),
182 "don't set alignment for a class with destructor");
183 return true;
184 }
185 return false;
186 }
187 SEAD_ASSERT_MSG(false, "size[%d] must be larger than zero", size);
188 return false;
189 }
190
191 bool tryAllocBuffer(s32 size, Heap* heap, s32 alignment = sizeof(void*))
192 {
193 SEAD_ASSERT(mBuffer == nullptr);
194 if (size > 0)
195 {
196 T* buffer = new (heap, alignment, std::nothrow) T[size];
197 if (buffer)
198 {
199 mSize = size;
200 mBuffer = buffer;
201 SEAD_ASSERT_MSG(PtrUtil::isAlignedPow2(mBuffer, abs(alignment)),
202 "don't set alignment for a class with destructor");
203 return true;
204 }
205 return false;
206 }
207 SEAD_ASSERT_MSG(false, "size[%d] must be larger than zero", size);
208 return false;
209 }
210
211 inline bool allocBufferAssert(s32 size, Heap* heap, s32 alignment = sizeof(void*))
212 {
213 if (tryAllocBuffer(size, heap, alignment))
214 return true;
215 AllocFailAssert(heap, sizeof(T) * size, alignment);
216 return false;
217 }
218
219 void freeBuffer()
220 {
221 if (mBuffer)
222 {
223 delete[] mBuffer;
224 mBuffer = nullptr;
225 mSize = 0;
226 }
227 }
228
229 void setBuffer(s32 size, T* bufferptr)
230 {
231 if (size < 1)
232 {
233 SEAD_ASSERT_MSG(false, "size[%d] must be larger than zero", size);
234 return;
235 }
236 if (!bufferptr)
237 {
238 SEAD_ASSERT_MSG(false, "bufferptr is null");
239 return;
240 }
241 mSize = size;
242 mBuffer = bufferptr;
243 }
244
245 bool isBufferReady() const { return mBuffer != nullptr; }
246
247 bool isIndexValid(s32 idx) const { return u32(idx) < u32(mSize); }
248
249 T& operator()(s32 idx) { return *unsafeGet(idx); }
250 const T& operator()(s32 idx) const { return *unsafeGet(idx); }
251
252 T& operator[](s32 idx)
253 {
254 if (u32(mSize) <= u32(idx))
255 {
256 SEAD_ASSERT_MSG(false, "index exceeded [%d/%d]", idx, mSize);
257 return mBuffer[0];
258 }
259 return mBuffer[idx];
260 }
261
262 const T& operator[](s32 idx) const
263 {
264 if (u32(mSize) <= u32(idx))
265 {
266 SEAD_ASSERT_MSG(false, "index exceeded [%d/%d]", idx, mSize);
267 return mBuffer[0];
268 }
269 return mBuffer[idx];
270 }
271
272 T* get(s32 idx)
273 {
274 if (u32(mSize) <= u32(idx))
275 {
276 SEAD_ASSERT_MSG(false, "index exceeded [%d/%d]", idx, mSize);
277 return nullptr;
278 }
279 return &mBuffer[idx];
280 }
281
282 const T* get(s32 idx) const
283 {
284 if (u32(mSize) <= u32(idx))
285 {
286 SEAD_ASSERT_MSG(false, "index exceeded [%d/%d]", idx, mSize);
287 return nullptr;
288 }
289 return &mBuffer[idx];
290 }
291
292 T* unsafeGet(s32 idx)
293 {
294 SEAD_ASSERT_MSG(u32(idx) < u32(mSize), "index exceeded [%d/%d]", idx, mSize);
295 return &mBuffer[idx];
296 }
297 const T* unsafeGet(s32 idx) const
298 {
299 SEAD_ASSERT_MSG(u32(idx) < u32(mSize), "index exceeded [%d/%d]", idx, mSize);
300 return &mBuffer[idx];
301 }
302
303 T& front() { return mBuffer[0]; }
304 const T& front() const { return mBuffer[0]; }
305
306 T& back() { return mBuffer[mSize - 1]; }
307 const T& back() const { return mBuffer[mSize - 1]; }
308
309 s32 size() const { return mSize; }
310 s32 getSize() const { return mSize; }
311
312 T* getBufferPtr() { return mBuffer; }
313 const T* getBufferPtr() const { return mBuffer; }
314
315 u32 getByteSize() const { return mSize * sizeof(T); }
316
317 void fill(const T& v)
318 {
319 for (s32 i = 0, n = mSize; i < n; ++i)
320 mBuffer[i] = v;
321 }
322
323 using CompareCallback = s32 (*)(const T* lhs, const T* rhs);
324
325 s32 binarySearch(const T& item) const { return binarySearch(item, compareT); }
326
327 s32 binarySearch(const T& item, CompareCallback cmp) const
328 {
329 if (mSize == 0)
330 return -1;
331
332 s32 a = 0;
333 s32 b = mSize - 1;
334 while (a < b)
335 {
336 const s32 m = (a + b) / 2;
337 const s32 c = cmp(&mBuffer[m], &item);
338 if (c == 0)
339 return m;
340 if (c < 0)
341 a = m + 1;
342 else
343 b = m;
344 }
345
346 if (cmp(&mBuffer[a], &item) == 0)
347 return a;
348
349 return -1;
350 }
351
352 template <typename Key>
353 s32 binarySearch(const Key& key, s32 (*cmp)(const T& item, const Key& key)) const
354 {
355 if (mSize == 0)
356 return -1;
357
358 s32 a = 0;
359 s32 b = mSize - 1;
360 while (a < b)
361 {
362 const s32 m = (a + b) / 2;
363 const s32 c = cmp(mBuffer[m], key);
364 if (c == 0)
365 return m;
366 if (c < 0)
367 a = m + 1;
368 else
369 b = m;
370 }
371
372 if (cmp(mBuffer[a], key) == 0)
373 return a;
374
375 return -1;
376 }
377
378 template <typename CustomCompareCallback>
379 s32 binarySearchC(CustomCompareCallback cmp) const
380 {
381 if (mSize == 0)
382 return -1;
383
384 s32 a = 0;
385 s32 b = mSize - 1;
386 while (a < b)
387 {
388 const s32 m = (a + b) / 2;
389 const s32 c = cmp(mBuffer[m]);
390 if (c == 0)
391 return m;
392 if (c < 0)
393 a = m + 1;
394 else
395 b = m;
396 }
397
398 if (cmp(mBuffer[a]) == 0)
399 return a;
400
401 return -1;
402 }
403
404 /// Sort elements with indices in [start_idx .. end_idx] using heapsort.
405 void heapSort(s32 start_idx, s32 end_idx)
406 {
407 if (start_idx >= mSize || end_idx >= mSize || end_idx - start_idx < 1)
408 return;
409 // FIXME: Nintendo implemented heap sort manually without using <algorithm>
410 std::make_heap(mBuffer + start_idx, mBuffer + end_idx);
411 std::sort_heap(mBuffer + start_idx, mBuffer + end_idx);
412 }
413
414 /// Sort elements with indices in [start_idx .. end_idx] using heapsort.
415 void heapSort(s32 start_idx, s32 end_idx, CompareCallback cmp)
416 {
417 if (start_idx >= mSize || end_idx >= mSize || end_idx - start_idx < 1)
418 return;
419 // FIXME: Nintendo implemented heap sort manually without using <algorithm>
420 const auto cmp_ = [cmp](const T& a, const T& b) { return cmp(&a, &b) < 0; };
421 std::make_heap(mBuffer + start_idx, mBuffer + end_idx, cmp_);
422 std::sort_heap(mBuffer + start_idx, mBuffer + end_idx, cmp_);
423 }
424
425protected:
426 static s32 compareT(const T* lhs, const T* rhs)
427 {
428 if (*lhs < *rhs)
429 return -1;
430 if (*rhs < *lhs)
431 return 1;
432 return 0;
433 }
434
435 // This is duplicated from Mathi::abs to avoid having to include the MathCalcCommon header;
436 // this limits the number of files we have to rebuild downstream whenever maths code is updated.
437 static s32 abs(s32 x) { return (x ^ x >> 31) - (x >> 31); }
438
439 s32 mSize;
440 T* mBuffer;
441};
442
443} // namespace sead
444
445#endif // SEAD_BUFFER_H_
446