| 1 | #pragma once |
| 2 | |
| 3 | #include <basis/seadTypes.h> |
| 4 | #include <prim/seadMemUtil.h> |
| 5 | #include <random/seadRandom.h> |
| 6 | #include <thread/seadAtomic.h> |
| 7 | |
| 8 | namespace sead { |
| 9 | class Heap; |
| 10 | } |
| 11 | |
| 12 | namespace agl { |
| 13 | |
| 14 | namespace detail { |
| 15 | |
| 16 | class AtomicPtrArrayImpl { |
| 17 | public: |
| 18 | AtomicPtrArrayImpl() = default; |
| 19 | AtomicPtrArrayImpl(s32 ptrNumMax, void* buf) { setBuffer(ptrNumMax, buf); } |
| 20 | |
| 21 | void setBuffer(s32 ptrNumMax, void* buf); |
| 22 | void allocBuffer(s32 ptrNumMax, sead::Heap* heap, s32 alignment = sizeof(void*)); |
| 23 | void freeBuffer(); |
| 24 | bool isBufferReady() const { return mPtrs != nullptr; } |
| 25 | |
| 26 | bool isEmpty() const { return mPtrNum == 0; } |
| 27 | bool isFull() const { return mPtrNum == mPtrNumMax; } |
| 28 | |
| 29 | s32 size() const { return mPtrNum; } |
| 30 | s32 capacity() const { return mPtrNumMax; } |
| 31 | |
| 32 | void erase(s32 position) { erase(position, count: 1); } |
| 33 | void erase(s32 position, s32 count); |
| 34 | void clear() { mPtrNum.exchange(value: 0); } |
| 35 | |
| 36 | void swap(s32 pos1, s32 pos2) { |
| 37 | auto* ptr = mPtrs[pos1]; |
| 38 | mPtrs[pos1] = mPtrs[pos2]; |
| 39 | mPtrs[pos2] = ptr; |
| 40 | } |
| 41 | void shuffle() { |
| 42 | sead::Random random; |
| 43 | shuffle(random: &random); |
| 44 | } |
| 45 | void shuffle(sead::Random* random); |
| 46 | |
| 47 | protected: |
| 48 | using CompareCallbackImpl = int (*)(const void* a, const void* b); |
| 49 | |
| 50 | void* at(s32 idx) const { |
| 51 | if (u32(mPtrNum) <= u32(idx)) { |
| 52 | SEAD_ASSERT_MSG(false, "index exceeded [%d/%d]" , idx, mPtrNum.load()); |
| 53 | return nullptr; |
| 54 | } |
| 55 | return mPtrs[idx]; |
| 56 | } |
| 57 | |
| 58 | void* unsafeAt(s32 idx) const { return mPtrs[idx]; } |
| 59 | |
| 60 | // XXX: should this use at()? |
| 61 | void* front() const { return mPtrs[0]; } |
| 62 | void* back() const { return mPtrs[mPtrNum - 1]; } |
| 63 | |
| 64 | void pushBack(void* ptr) { |
| 65 | const s32 idx = mPtrNum++; |
| 66 | SEAD_ASSERT_MSG(idx < mPtrNumMax, "index = %d, mPtrNumMax = %d" , idx, mPtrNumMax.load()); |
| 67 | mPtrs[idx] = ptr; |
| 68 | } |
| 69 | |
| 70 | void* popBack() { return isEmpty() ? nullptr : mPtrs[--mPtrNum]; } |
| 71 | |
| 72 | void* popFront() { |
| 73 | if (isEmpty()) |
| 74 | return nullptr; |
| 75 | |
| 76 | void* result = mPtrs[0]; |
| 77 | erase(position: 0); |
| 78 | return result; |
| 79 | } |
| 80 | |
| 81 | void* find(const void* ptr, CompareCallbackImpl cmp) const { |
| 82 | for (s32 i = 0; i < mPtrNum; ++i) { |
| 83 | if (cmp(ptr, mPtrs[i]) == 0) |
| 84 | return mPtrs[i]; |
| 85 | } |
| 86 | return nullptr; |
| 87 | } |
| 88 | |
| 89 | s32 search(const void* ptr, CompareCallbackImpl cmp) const { |
| 90 | for (s32 i = 0; i < mPtrNum; ++i) { |
| 91 | if (cmp(ptr, mPtrs[i]) == 0) |
| 92 | return i; |
| 93 | } |
| 94 | return -1; |
| 95 | } |
| 96 | |
| 97 | bool equal(const AtomicPtrArrayImpl& other, CompareCallbackImpl cmp) const { |
| 98 | if (mPtrNum != other.mPtrNum) |
| 99 | return false; |
| 100 | |
| 101 | for (s32 i = 0; i < mPtrNum; ++i) { |
| 102 | if (cmp(mPtrs[i], other.mPtrs[i]) != 0) |
| 103 | return false; |
| 104 | } |
| 105 | return true; |
| 106 | } |
| 107 | |
| 108 | s32 indexOf(const void* ptr) const { |
| 109 | for (s32 i = 0; i < mPtrNum; ++i) { |
| 110 | if (mPtrs[i] == ptr) |
| 111 | return i; |
| 112 | } |
| 113 | return -1; |
| 114 | } |
| 115 | |
| 116 | void sort(CompareCallbackImpl cmp); |
| 117 | void heapSort(CompareCallbackImpl cmp); |
| 118 | |
| 119 | sead::Atomic<s32> mPtrNum = 0; |
| 120 | sead::Atomic<s32> mPtrNumMax = 0; |
| 121 | void** mPtrs = nullptr; |
| 122 | }; |
| 123 | |
| 124 | } // namespace detail |
| 125 | |
| 126 | namespace utl { |
| 127 | |
| 128 | template <typename T> |
| 129 | class AtomicPtrArray : public detail::AtomicPtrArrayImpl { |
| 130 | public: |
| 131 | AtomicPtrArray() = default; |
| 132 | AtomicPtrArray(s32 ptrNumMax, T** buf) : AtomicPtrArrayImpl(ptrNumMax, buf) {} |
| 133 | |
| 134 | T* at(s32 pos) const { return static_cast<T*>(AtomicPtrArrayImpl::at(idx: pos)); } |
| 135 | T* unsafeAt(s32 pos) const { return static_cast<T*>(AtomicPtrArrayImpl::unsafeAt(idx: pos)); } |
| 136 | T* operator[](s32 pos) const { return at(pos); } |
| 137 | |
| 138 | // XXX: Does this use at()? |
| 139 | T* front() const { return at(pos: 0); } |
| 140 | T* back() const { return at(pos: mPtrNum - 1); } |
| 141 | |
| 142 | void pushBack(T* ptr) { AtomicPtrArrayImpl::pushBack(ptr); } |
| 143 | |
| 144 | T* popBack() { return static_cast<T*>(AtomicPtrArrayImpl::popBack()); } |
| 145 | T* popFront() { return static_cast<T*>(AtomicPtrArrayImpl::popFront()); } |
| 146 | |
| 147 | s32 indexOf(const T* ptr) const { return AtomicPtrArrayImpl::indexOf(ptr); } |
| 148 | |
| 149 | using CompareCallback = s32 (*)(const T*, const T*); |
| 150 | |
| 151 | void sort() { AtomicPtrArrayImpl::sort(cmp: compareT); } |
| 152 | void sort(CompareCallback cmp) { AtomicPtrArrayImpl::sort(cmp); } |
| 153 | void heapSort() { AtomicPtrArrayImpl::heapSort(cmp: compareT); } |
| 154 | void heapSort(CompareCallback cmp) { AtomicPtrArrayImpl::heapSort(cmp); } |
| 155 | |
| 156 | bool equal(const AtomicPtrArray& other, CompareCallback cmp) const { |
| 157 | return AtomicPtrArrayImpl::equal(other, cmp); |
| 158 | } |
| 159 | |
| 160 | T* find(const T* ptr) const { return AtomicPtrArrayImpl::find(ptr, cmp: compareT); } |
| 161 | T* find(const T* ptr, CompareCallback cmp) const { return AtomicPtrArrayImpl::find(ptr, cmp); } |
| 162 | s32 search(const T* ptr) const { return AtomicPtrArrayImpl::search(ptr, cmp: compareT); } |
| 163 | s32 search(const T* ptr, CompareCallback cmp) const { |
| 164 | return AtomicPtrArrayImpl::search(ptr, cmp); |
| 165 | } |
| 166 | |
| 167 | bool operator==(const AtomicPtrArray& other) const { return equal(other, cmp: compareT); } |
| 168 | bool operator!=(const AtomicPtrArray& other) const { return !(*this == other); } |
| 169 | |
| 170 | class iterator { |
| 171 | public: |
| 172 | iterator(T* const* pptr) : mPPtr{pptr} {} |
| 173 | bool operator==(const iterator& other) const { return mPPtr == other.mPPtr; } |
| 174 | bool operator!=(const iterator& other) const { return !(*this == other); } |
| 175 | iterator& operator++() { |
| 176 | ++mPPtr; |
| 177 | return *this; |
| 178 | } |
| 179 | T& operator*() const { return **mPPtr; } |
| 180 | T* operator->() const { return *mPPtr; } |
| 181 | |
| 182 | private: |
| 183 | T* const* mPPtr; |
| 184 | }; |
| 185 | |
| 186 | iterator begin() const { return iterator(data()); } |
| 187 | iterator end() const { return iterator(data() + mPtrNum); } |
| 188 | |
| 189 | class constIterator { |
| 190 | public: |
| 191 | constIterator(const T* const* pptr) : mPPtr{pptr} {} |
| 192 | bool operator==(const constIterator& other) const { return mPPtr == other.mPPtr; } |
| 193 | bool operator!=(const constIterator& other) const { return !(*this == other); } |
| 194 | constIterator& operator++() { |
| 195 | ++mPPtr; |
| 196 | return *this; |
| 197 | } |
| 198 | const T& operator*() const { return **mPPtr; } |
| 199 | const T* operator->() const { return *mPPtr; } |
| 200 | |
| 201 | private: |
| 202 | const T* const* mPPtr; |
| 203 | }; |
| 204 | |
| 205 | constIterator constBegin() const { return constIterator(data()); } |
| 206 | constIterator constEnd() const { return constIterator(data() + mPtrNum); } |
| 207 | |
| 208 | T** data() const { return reinterpret_cast<T**>(mPtrs); } |
| 209 | |
| 210 | protected: |
| 211 | static int compareT(const T* a, const T* b) { |
| 212 | if (*a < *b) |
| 213 | return -1; |
| 214 | if (*a > *b) |
| 215 | return 1; |
| 216 | return 0; |
| 217 | } |
| 218 | }; |
| 219 | |
| 220 | template <typename T, s32 N> |
| 221 | class FixedPtrArray : public AtomicPtrArray<T> { |
| 222 | public: |
| 223 | FixedPtrArray() { AtomicPtrArray<T>::setBuffer(N, mWork); } |
| 224 | |
| 225 | // These do not make sense for a *fixed* array. |
| 226 | void setBuffer(s32 ptrNumMax, void* buf) = delete; |
| 227 | void allocBuffer(s32 ptrNumMax, sead::Heap* heap, s32 alignment = sizeof(void*)) = delete; |
| 228 | bool tryAllocBuffer(s32 ptrNumMax, sead::Heap* heap, s32 alignment = sizeof(void*)) = delete; |
| 229 | void freeBuffer() = delete; |
| 230 | |
| 231 | private: |
| 232 | // Nintendo uses an untyped u8[N*sizeof(void*)] buffer. That is undefined behavior, |
| 233 | // so we will not do that. |
| 234 | T* mWork[N]; |
| 235 | }; |
| 236 | |
| 237 | } // namespace utl |
| 238 | |
| 239 | } // namespace agl |
| 240 | |