1#pragma once
2
3#include <basis/seadTypes.h>
4#include <prim/seadMemUtil.h>
5#include <random/seadRandom.h>
6#include <thread/seadAtomic.h>
7
// Forward declaration only: this header mentions sead::Heap solely as a pointer type.
namespace sead {
class Heap;
}  // namespace sead
11
12namespace agl {
13
14namespace detail {
15
16class AtomicPtrArrayImpl {
17public:
18 AtomicPtrArrayImpl() = default;
19 AtomicPtrArrayImpl(s32 ptrNumMax, void* buf) { setBuffer(ptrNumMax, buf); }
20
21 void setBuffer(s32 ptrNumMax, void* buf);
22 void allocBuffer(s32 ptrNumMax, sead::Heap* heap, s32 alignment = sizeof(void*));
23 void freeBuffer();
24 bool isBufferReady() const { return mPtrs != nullptr; }
25
26 bool isEmpty() const { return mPtrNum == 0; }
27 bool isFull() const { return mPtrNum == mPtrNumMax; }
28
29 s32 size() const { return mPtrNum; }
30 s32 capacity() const { return mPtrNumMax; }
31
32 void erase(s32 position) { erase(position, count: 1); }
33 void erase(s32 position, s32 count);
34 void clear() { mPtrNum.exchange(value: 0); }
35
36 void swap(s32 pos1, s32 pos2) {
37 auto* ptr = mPtrs[pos1];
38 mPtrs[pos1] = mPtrs[pos2];
39 mPtrs[pos2] = ptr;
40 }
41 void shuffle() {
42 sead::Random random;
43 shuffle(random: &random);
44 }
45 void shuffle(sead::Random* random);
46
47protected:
48 using CompareCallbackImpl = int (*)(const void* a, const void* b);
49
50 void* at(s32 idx) const {
51 if (u32(mPtrNum) <= u32(idx)) {
52 SEAD_ASSERT_MSG(false, "index exceeded [%d/%d]", idx, mPtrNum.load());
53 return nullptr;
54 }
55 return mPtrs[idx];
56 }
57
58 void* unsafeAt(s32 idx) const { return mPtrs[idx]; }
59
60 // XXX: should this use at()?
61 void* front() const { return mPtrs[0]; }
62 void* back() const { return mPtrs[mPtrNum - 1]; }
63
64 void pushBack(void* ptr) {
65 const s32 idx = mPtrNum++;
66 SEAD_ASSERT_MSG(idx < mPtrNumMax, "index = %d, mPtrNumMax = %d", idx, mPtrNumMax.load());
67 mPtrs[idx] = ptr;
68 }
69
70 void* popBack() { return isEmpty() ? nullptr : mPtrs[--mPtrNum]; }
71
72 void* popFront() {
73 if (isEmpty())
74 return nullptr;
75
76 void* result = mPtrs[0];
77 erase(position: 0);
78 return result;
79 }
80
81 void* find(const void* ptr, CompareCallbackImpl cmp) const {
82 for (s32 i = 0; i < mPtrNum; ++i) {
83 if (cmp(ptr, mPtrs[i]) == 0)
84 return mPtrs[i];
85 }
86 return nullptr;
87 }
88
89 s32 search(const void* ptr, CompareCallbackImpl cmp) const {
90 for (s32 i = 0; i < mPtrNum; ++i) {
91 if (cmp(ptr, mPtrs[i]) == 0)
92 return i;
93 }
94 return -1;
95 }
96
97 bool equal(const AtomicPtrArrayImpl& other, CompareCallbackImpl cmp) const {
98 if (mPtrNum != other.mPtrNum)
99 return false;
100
101 for (s32 i = 0; i < mPtrNum; ++i) {
102 if (cmp(mPtrs[i], other.mPtrs[i]) != 0)
103 return false;
104 }
105 return true;
106 }
107
108 s32 indexOf(const void* ptr) const {
109 for (s32 i = 0; i < mPtrNum; ++i) {
110 if (mPtrs[i] == ptr)
111 return i;
112 }
113 return -1;
114 }
115
116 void sort(CompareCallbackImpl cmp);
117 void heapSort(CompareCallbackImpl cmp);
118
119 sead::Atomic<s32> mPtrNum = 0;
120 sead::Atomic<s32> mPtrNumMax = 0;
121 void** mPtrs = nullptr;
122};
123
124} // namespace detail
125
126namespace utl {
127
128template <typename T>
129class AtomicPtrArray : public detail::AtomicPtrArrayImpl {
130public:
131 AtomicPtrArray() = default;
132 AtomicPtrArray(s32 ptrNumMax, T** buf) : AtomicPtrArrayImpl(ptrNumMax, buf) {}
133
134 T* at(s32 pos) const { return static_cast<T*>(AtomicPtrArrayImpl::at(idx: pos)); }
135 T* unsafeAt(s32 pos) const { return static_cast<T*>(AtomicPtrArrayImpl::unsafeAt(idx: pos)); }
136 T* operator[](s32 pos) const { return at(pos); }
137
138 // XXX: Does this use at()?
139 T* front() const { return at(pos: 0); }
140 T* back() const { return at(pos: mPtrNum - 1); }
141
142 void pushBack(T* ptr) { AtomicPtrArrayImpl::pushBack(ptr); }
143
144 T* popBack() { return static_cast<T*>(AtomicPtrArrayImpl::popBack()); }
145 T* popFront() { return static_cast<T*>(AtomicPtrArrayImpl::popFront()); }
146
147 s32 indexOf(const T* ptr) const { return AtomicPtrArrayImpl::indexOf(ptr); }
148
149 using CompareCallback = s32 (*)(const T*, const T*);
150
151 void sort() { AtomicPtrArrayImpl::sort(cmp: compareT); }
152 void sort(CompareCallback cmp) { AtomicPtrArrayImpl::sort(cmp); }
153 void heapSort() { AtomicPtrArrayImpl::heapSort(cmp: compareT); }
154 void heapSort(CompareCallback cmp) { AtomicPtrArrayImpl::heapSort(cmp); }
155
156 bool equal(const AtomicPtrArray& other, CompareCallback cmp) const {
157 return AtomicPtrArrayImpl::equal(other, cmp);
158 }
159
160 T* find(const T* ptr) const { return AtomicPtrArrayImpl::find(ptr, cmp: compareT); }
161 T* find(const T* ptr, CompareCallback cmp) const { return AtomicPtrArrayImpl::find(ptr, cmp); }
162 s32 search(const T* ptr) const { return AtomicPtrArrayImpl::search(ptr, cmp: compareT); }
163 s32 search(const T* ptr, CompareCallback cmp) const {
164 return AtomicPtrArrayImpl::search(ptr, cmp);
165 }
166
167 bool operator==(const AtomicPtrArray& other) const { return equal(other, cmp: compareT); }
168 bool operator!=(const AtomicPtrArray& other) const { return !(*this == other); }
169
170 class iterator {
171 public:
172 iterator(T* const* pptr) : mPPtr{pptr} {}
173 bool operator==(const iterator& other) const { return mPPtr == other.mPPtr; }
174 bool operator!=(const iterator& other) const { return !(*this == other); }
175 iterator& operator++() {
176 ++mPPtr;
177 return *this;
178 }
179 T& operator*() const { return **mPPtr; }
180 T* operator->() const { return *mPPtr; }
181
182 private:
183 T* const* mPPtr;
184 };
185
186 iterator begin() const { return iterator(data()); }
187 iterator end() const { return iterator(data() + mPtrNum); }
188
189 class constIterator {
190 public:
191 constIterator(const T* const* pptr) : mPPtr{pptr} {}
192 bool operator==(const constIterator& other) const { return mPPtr == other.mPPtr; }
193 bool operator!=(const constIterator& other) const { return !(*this == other); }
194 constIterator& operator++() {
195 ++mPPtr;
196 return *this;
197 }
198 const T& operator*() const { return **mPPtr; }
199 const T* operator->() const { return *mPPtr; }
200
201 private:
202 const T* const* mPPtr;
203 };
204
205 constIterator constBegin() const { return constIterator(data()); }
206 constIterator constEnd() const { return constIterator(data() + mPtrNum); }
207
208 T** data() const { return reinterpret_cast<T**>(mPtrs); }
209
210protected:
211 static int compareT(const T* a, const T* b) {
212 if (*a < *b)
213 return -1;
214 if (*a > *b)
215 return 1;
216 return 0;
217 }
218};
219
220template <typename T, s32 N>
221class FixedPtrArray : public AtomicPtrArray<T> {
222public:
223 FixedPtrArray() { AtomicPtrArray<T>::setBuffer(N, mWork); }
224
225 // These do not make sense for a *fixed* array.
226 void setBuffer(s32 ptrNumMax, void* buf) = delete;
227 void allocBuffer(s32 ptrNumMax, sead::Heap* heap, s32 alignment = sizeof(void*)) = delete;
228 bool tryAllocBuffer(s32 ptrNumMax, sead::Heap* heap, s32 alignment = sizeof(void*)) = delete;
229 void freeBuffer() = delete;
230
231private:
232 // Nintendo uses an untyped u8[N*sizeof(void*)] buffer. That is undefined behavior,
233 // so we will not do that.
234 T* mWork[N];
235};
236
237} // namespace utl
238
239} // namespace agl
240