1#pragma once
2
3#include <nn/types.h>
4#include <type_traits>
5
6namespace nn::util {
7
8namespace detail {
9
// Bitwise-ANDs the first `count` words of `source` into `dest`, in place.
// Words are processed front to back; a non-positive `count` changes nothing.
template <typename Integer>
void AndEqual(Integer* dest, const Integer* source, int count) {
    for (int remaining = count; remaining > 0; --remaining) {
        *dest &= *source;
        ++dest;
        ++source;
    }
}
16
// Bitwise-XORs the first `count` words of `source` into `dest`, in place.
// Words are processed front to back; a non-positive `count` changes nothing.
template <typename Integer>
void XorEqual(Integer* dest, const Integer* source, int count) {
    int word = 0;
    while (word < count) {
        dest[word] ^= source[word];
        ++word;
    }
}
23
// Bitwise-ORs the first `count` words of `source` into `dest`, in place.
// Words are processed front to back; a non-positive `count` changes nothing.
template <typename Integer>
void OrEqual(Integer* dest, const Integer* source, int count) {
    Integer* out = dest;
    const Integer* in = source;
    for (int done = 0; done < count; ++done, ++out, ++in) {
        *out |= *in;
    }
}
30
// Reports whether the first `count` words of `dest` and `source` match.
// Stops at the first differing word; a non-positive `count` compares nothing
// and yields true.
template <typename Integer>
bool Equals(const Integer* dest, const Integer* source, int count) {
    int word = 0;
    while (word < count) {
        if (dest[word] != source[word]) {
            break;
        }
        ++word;
    }

    // The scan only runs to completion when no mismatch was found.
    return word >= count;
}
41
// Reports whether any of the first `count` words of `_storage` is non-zero.
// Stops at the first non-zero word; a non-positive `count` yields false.
template <typename Integer>
bool IsAnyOn(const Integer* _storage, int count) {
    bool found = false;
    for (int word = 0; !found && word < count; ++word) {
        found = (_storage[word] != 0);
    }
    return found;
}
52
53} // namespace detail
54
55template <int N, typename Tag>
56struct BitFlagSet {
57 typedef typename std::conditional<N <= 32, uint32_t, uint64_t>::type StorageT;
58
59 static const int StorageBitCount = 8 * sizeof(StorageT);
60 static const int StorageCount = (N + StorageBitCount - 1) / StorageBitCount;
61 StorageT _storage[StorageCount]{};
62
63 class Reference {
64 public:
65 Reference& operator=(bool isOn) {
66 m_Set->Set(m_Index, isOn);
67 return *this;
68 }
69
70 Reference& operator=(const Reference& ref) {
71 m_Set->Set(m_Index, ref);
72 return *this;
73 }
74
75 Reference& Flip() {
76 m_Set->Flip(m_Index);
77 return *this;
78 }
79
80 bool operator~() const { return !m_Set->Test(m_Index); }
81
82 operator bool() const { return m_Set->Test(m_Index); }
83
84 private:
85 BitFlagSet* m_Set;
86 int m_Index;
87
88 Reference() : m_Set(nullptr), m_Index(0) {}
89 Reference(BitFlagSet& set, int index) : m_Set(&set), m_Index(index) {}
90 };
91
92 BitFlagSet operator~() const {
93 BitFlagSet tmp = *this;
94 tmp.Flip();
95 return tmp;
96 }
97
98 BitFlagSet operator&(const BitFlagSet& other) const {
99 BitFlagSet value = *this;
100 value &= other;
101 return value;
102 }
103
104 BitFlagSet operator^(const BitFlagSet& other) const {
105 BitFlagSet value = *this;
106 value ^= other;
107 return value;
108 }
109
110 BitFlagSet operator|(const BitFlagSet& other) const {
111 BitFlagSet value = *this;
112 value |= other;
113 return value;
114 }
115
116 BitFlagSet& operator&=(const BitFlagSet& other) {
117 detail::AndEqual(_storage, other._storage, StorageCount);
118 return *this;
119 }
120
121 BitFlagSet& operator^=(const BitFlagSet& other) {
122 detail::XorEqual(_storage, other._storage, StorageCount);
123 return *this;
124 }
125
126 BitFlagSet& operator|=(const BitFlagSet& other) {
127 detail::OrEqual(_storage, other._storage, StorageCount);
128 return *this;
129 }
130
131 bool operator==(const BitFlagSet& other) const {
132 return detail::Equals(_storage, other._storage, StorageCount);
133 }
134
135 bool operator!=(const BitFlagSet& other) const {
136 return !detail::Equals(_storage, other._storage, StorageCount);
137 }
138
139 bool operator[](int index) const { return Test(index); }
140 Reference operator[](int index) { return Reference(*this, index); }
141
142 bool IsAnyOn() const { return detail::IsAnyOn(_storage, StorageCount); }
143
144 // https://en.wikichip.org/wiki/population_count
145 int CountPopulation() const {
146 int c = 0;
147 for (int i = 0; i < StorageCount; ++i) {
148 StorageT x = _storage[i];
149 for (; x != 0; x &= x - 1)
150 c++;
151 }
152 return c;
153 }
154
155 BitFlagSet& Flip(int index) { return Set(index, !Test(index)); }
156
157 BitFlagSet& Flip() {
158 for (int i = 0; i < StorageCount; ++i) {
159 _storage[i] = ~_storage[i];
160 }
161 Truncate();
162
163 return *this;
164 }
165
166 bool IsAllOn() const { return CountPopulation() == N; }
167 bool IsAllOff() const { return CountPopulation() == 0; }
168
169 BitFlagSet& Reset() {
170 for (int i = 0; i < StorageCount; ++i) {
171 _storage[i] = 0;
172 }
173 return *this;
174 }
175
176 BitFlagSet& Reset(int index) { return Set(index, false); }
177
178 BitFlagSet& Set() {
179 for (int i = 0; i < StorageCount; ++i) {
180 _storage[i] = ~0;
181 }
182 Truncate();
183
184 return *this;
185 }
186
187 BitFlagSet& Set(int index, bool isOn = true) {
188 // todo: add assert to verify index is valid
189 return SetImpl(storageIndex: GetStorageIndex(index), storageMask: MakeStorageMask(index), isOn);
190 }
191
192 template <typename FlagT>
193 BitFlagSet& Set(bool isOn = true) const {
194 return SetImpl(storageIndex: FlagT::StorageIndex, storageMask: FlagT::StorageMask, isOn);
195 }
196
197 int GetCount() const { return N; }
198
199 bool Test(int index) const {
200 // todo: add assert to verify index is valid
201 return TestImpl(storageIndex: GetStorageIndex(index), storageMask: MakeStorageMask(index));
202 }
203
204 template <typename FlagT>
205 bool Test() const {
206 return TestImpl(storageIndex: FlagT::StorageIndex, storageMask: FlagT::StorageMask);
207 }
208
209 template <int BitIndex>
210 struct Flag {
211 static_assert(BitIndex < N, "BitIndex < N");
212
213 static constexpr BitFlagSet buildMask() {
214 BitFlagSet tmp;
215 tmp._storage[StorageIndex] = StorageMask;
216 return tmp;
217 }
218
219 static constexpr int Index = BitIndex;
220 static constexpr BitFlagSet Mask = buildMask();
221
222 private:
223 static constexpr int StorageIndex = BitIndex / StorageBitCount;
224 static constexpr StorageT StorageMask = StorageT(1) << (BitIndex % StorageBitCount);
225 };
226
227private:
228 BitFlagSet& SetImpl(int storageIndex, StorageT storageMask, bool isOn) {
229 if (isOn) {
230 _storage[storageIndex] |= storageMask;
231 } else {
232 _storage[storageIndex] &= ~storageMask;
233 }
234 return *this;
235 }
236
237 bool TestImpl(int storageIndex, StorageT storageMask) const {
238 return _storage[storageIndex] & storageMask;
239 }
240
241 void Truncate() { TruncateIf(std::integral_constant<bool, (N % StorageBitCount) != 0>{}); }
242
243 void TruncateIf(std::true_type) { _storage[StorageCount - 1] &= MakeStorageMask(index: N) - 1; }
244 void TruncateIf(std::false_type) {}
245
246 static int GetStorageIndex(int index) { return index / StorageBitCount; }
247 static StorageT MakeStorageMask(int index) { return StorageT(1) << (index % StorageBitCount); }
248};
249
250} // namespace nn::util
251