1#pragma once
2
3#include <algorithm>
4#include "basis/seadNew.h"
5#include "basis/seadRawPrint.h"
6#include "container/seadFreeList.h"
7#include "container/seadPtrArray.h"
8
9namespace sead
10{
11/// An ObjArray is a container that allocates elements using a FreeList and also keeps an array of
12/// pointers for fast access to each element.
13template <typename T>
14class ObjArray : public PtrArrayImpl
15{
16public:
17 ObjArray() = default;
18 ObjArray(s32 max_num, void* buf) { setBuffer(max_num, buf); }
19
20 void allocBuffer(s32 capacity, Heap* heap, s32 alignment = sizeof(void*))
21 {
22 SEAD_ASSERT(mPtrs == nullptr);
23
24 if (capacity < 1)
25 {
26 SEAD_ASSERT_MSG(false, "capacity[%d] must be larger than zero", capacity);
27 return;
28 }
29
30 setBuffer(max_num: capacity,
31 buf: new (heap, alignment, std::nothrow) u8[calculateWorkBufferSize(n: capacity)]);
32 }
33
34 bool tryAllocBuffer(s32 capacity, Heap* heap, s32 alignment = sizeof(void*))
35 {
36 SEAD_ASSERT(mPtrs == nullptr);
37
38 if (capacity < 1)
39 {
40 SEAD_ASSERT_MSG(false, "capacity[%d] must be larger than zero", capacity);
41 return false;
42 }
43
44 auto* buf = new (heap, alignment, std::nothrow) u8[calculateWorkBufferSize(n: capacity)];
45 if (!buf)
46 return false;
47
48 setBuffer(max_num: capacity, buf);
49 return true;
50 }
51
52 void setBuffer(s32 max_num, void* buf)
53 {
54 if (!buf)
55 {
56 SEAD_ASSERT_MSG(false, "buf is null");
57 return;
58 }
59
60 mFreeList.setWork(work: buf, elem_size: ElementSize, num: max_num);
61 PtrArrayImpl::setBuffer(ptrNumMax: max_num, buf: reinterpret_cast<u8*>(buf) + ElementSize * max_num);
62 }
63
64 void freeBuffer()
65 {
66 if (!isBufferReady())
67 return;
68
69 clear();
70
71 if (mFreeList.work())
72 delete[] static_cast<u8*>(mFreeList.work());
73
74 mFreeList.reset();
75 mPtrs = nullptr;
76 mPtrNumMax = 0;
77 }
78
79 T* at(s32 pos) const { return static_cast<T*>(PtrArrayImpl::at(idx: pos)); }
80 T* unsafeAt(s32 pos) const { return static_cast<T*>(PtrArrayImpl::unsafeAt(idx: pos)); }
81 T* operator()(s32 pos) const { return unsafeAt(pos); }
82 T* operator[](s32 pos) const { return at(pos); }
83
84 // XXX: Does this use at()?
85 T* front() const { return at(pos: 0); }
86 T* back() const { return at(pos: mPtrNum - 1); }
87
88 void pushBack(const T& item)
89 {
90 if (isFull())
91 SEAD_ASSERT_MSG(false, "buffer full.");
92 else
93 PtrArrayImpl::pushBack(ptr: alloc(item));
94 }
95
96 template <class... Args>
97 T* emplaceBack(Args&&... args)
98 {
99 if (isFull())
100 {
101 SEAD_ASSERT_MSG(false, "buffer full.");
102 return nullptr;
103 }
104 T* item = new (mFreeList.alloc()) T(std::forward<Args>(args)...);
105 PtrArrayImpl::pushBack(ptr: item);
106 return item;
107 }
108
109 void insert(s32 pos, const T& item) { PtrArrayImpl::insert(idx: pos, ptr: alloc(item)); }
110
111 void erase(int index) { erase(index, 1); }
112
113 void erase(int index, int count)
114 {
115 if (index + count <= size())
116 {
117 for (int i = index; i < index + count; ++i)
118 {
119 auto* ptr = unsafeAt(pos: i);
120 ptr->~T();
121 mFreeList.free(ptr);
122 }
123 }
124 PtrArrayImpl::erase(index, count);
125 }
126
127 void clear()
128 {
129 for (s32 i = 0; i < mPtrNum; ++i)
130 {
131 auto* ptr = unsafeAt(pos: i);
132 ptr->~T();
133 mFreeList.free(ptr);
134 }
135 mPtrNum = 0;
136 }
137
138 using CompareCallback = s32 (*)(const T*, const T*);
139
140 void sort() { sort(compareT); }
141 void sort(CompareCallback cmp) { PtrArrayImpl::sort_<T>(cmp); }
142 void heapSort() { heapSort(compareT); }
143 void heapSort(CompareCallback cmp) { PtrArrayImpl::heapSort_<T>(cmp); }
144
145 bool equal(const ObjArray& other, CompareCallback cmp) const
146 {
147 return PtrArrayImpl::equal(other, cmp);
148 }
149
150 s32 compare(const ObjArray& other, CompareCallback cmp) const
151 {
152 return PtrArrayImpl::compare(other, cmp);
153 }
154
155 s32 binarySearch(const T* ptr) const { return PtrArrayImpl::binarySearch(ptr, cmp: compareT); }
156 s32 binarySearch(const T* ptr, CompareCallback cmp) const
157 {
158 return PtrArrayImpl::binarySearch(ptr, cmp);
159 }
160
161 bool operator==(const ObjArray& other) const { return equal(other, cmp: compareT); }
162 bool operator!=(const ObjArray& other) const { return !(*this == other); }
163 bool operator<(const ObjArray& other) const { return compare(other) < 0; }
164 bool operator<=(const ObjArray& other) const { return compare(other) <= 0; }
165 bool operator>(const ObjArray& other) const { return compare(other) > 0; }
166 bool operator>=(const ObjArray& other) const { return compare(other) >= 0; }
167
168 void uniq() { PtrArrayImpl::uniq(cmp: compareT); }
169 void uniq(CompareCallback cmp) { PtrArrayImpl::uniq(cmp); }
170
171 class iterator
172 {
173 public:
174 iterator(T* const* pptr) : mPPtr{pptr} {}
175 bool operator==(const iterator& other) const { return mPPtr == other.mPPtr; }
176 bool operator!=(const iterator& other) const { return !(*this == other); }
177 iterator& operator++()
178 {
179 ++mPPtr;
180 return *this;
181 }
182 T& operator*() const { return **mPPtr; }
183 T* operator->() const { return *mPPtr; }
184
185 private:
186 T* const* mPPtr;
187 };
188
189 iterator begin() const { return iterator(data()); }
190 iterator end() const { return iterator(data() + mPtrNum); }
191
192 class constIterator
193 {
194 public:
195 constIterator(const T* const* pptr) : mPPtr{pptr} {}
196 bool operator==(const constIterator& other) const { return mPPtr == other.mPPtr; }
197 bool operator!=(const constIterator& other) const { return !(*this == other); }
198 constIterator& operator++()
199 {
200 ++mPPtr;
201 return *this;
202 }
203 const T& operator*() const { return **mPPtr; }
204 const T* operator->() const { return *mPPtr; }
205
206 private:
207 const T* const* mPPtr;
208 };
209
210 constIterator constBegin() const { return constIterator(data()); }
211 constIterator constEnd() const { return constIterator(data() + mPtrNum); }
212
213 T** data() const { return reinterpret_cast<T**>(mPtrs); }
214
215private:
216 union Node
217 {
218 void* next_node;
219 T elem;
220 };
221
222public:
223 static constexpr size_t ElementSize = sizeof(Node);
224
225 static constexpr size_t calculateWorkBufferSize(size_t n)
226 {
227 return n * (ElementSize + sizeof(T*));
228 }
229
230protected:
231 T* alloc(const T& item)
232 {
233 void* storage = mFreeList.alloc();
234 if (!storage)
235 return nullptr;
236 return new (storage) T(item);
237 }
238
239 static int compareT(const void* a_, const void* b_)
240 {
241 const T* a = static_cast<const T*>(a_);
242 const T* b = static_cast<const T*>(b_);
243 if (*a < *b)
244 return -1;
245 if (*b < *a)
246 return 1;
247 return 0;
248 }
249
250 sead::FreeList mFreeList;
251};
252
253template <typename T, s32 N>
254class FixedObjArray : public ObjArray<T>
255{
256public:
257 FixedObjArray() : ObjArray<T>(N, &mWork) {}
258
259 // These do not make sense for a *fixed* array.
260 void setBuffer(s32 ptrNumMax, void* buf) = delete;
261 void allocBuffer(s32 ptrNumMax, Heap* heap, s32 alignment = sizeof(void*)) = delete;
262 bool tryAllocBuffer(s32 ptrNumMax, Heap* heap, s32 alignment = sizeof(void*)) = delete;
263 void freeBuffer() = delete;
264
265private:
266 std::aligned_storage_t<ObjArray<T>::calculateWorkBufferSize(N),
267 std::max(a: alignof(T), b: alignof(T*))>
268 mWork;
269};
270
271} // namespace sead
272