1#pragma once
2
3#include <basis/seadRawPrint.h>
4#include <basis/seadTypes.h>
5
6namespace sead
7{
8class BitFlagUtil
9{
10public:
11 /// Popcount.
12 static int countOnBit(u32 x);
13 /// Count trailing zeroes (ctz).
14 static int countContinuousOffBitFromRight(u32 x) { return countOnBit(x: (x & -x) - 1); }
15 static int countRightOnBit(u32 x, int bit);
16 static int findOnBitFromRight(u32 x, int num);
17
18 /// Popcount.
19 static int countOnBit64(u64 x)
20 {
21 return countOnBit(x: static_cast<u32>(x)) + countOnBit(x: static_cast<u32>(x >> 32));
22 }
23 /// Count trailing zeroes (ctz).
24 static int countContinuousOffBitFromRight64(u64 x) { return countOnBit64(x: (x & -x) - 1); }
25 static int countRightOnBit64(u64 x, int bit);
26 static int findOnBitFromRight64(u64 x, int num);
27};
28
29template <typename T>
30class BitFlag
31{
32public:
33 BitFlag() : mBits(0) {}
34 BitFlag(T bits) : mBits(bits) {}
35
36 operator T() const { return mBits; }
37
38 void makeAllZero() { mBits = 0; }
39 void makeAllOne() { mBits = ~T(0); }
40
41 void setDirect(T bits) { mBits = bits; }
42 T getDirect() const { return mBits; }
43 T* getPtr() { return &mBits; }
44 const T* getPtr() const { return &mBits; }
45 size_t getByteSize() const { return sizeof(T); }
46
47 void set(T val) { mBits |= val; }
48 void reset(T val) { mBits &= ~val; }
49 void toggle(T val) { mBits ^= val; }
50 void change(T val, bool on)
51 {
52 if (on)
53 set(val);
54 else
55 reset(val);
56 }
57 bool isZero() const { return mBits == 0; }
58 /// Checks if (at least) one of the bits are set.
59 bool isOn(T val) const { return (mBits & val) != 0; }
60 /// Checks if all of the bits are set.
61 bool isOnAll(T val) const { return (mBits & val) == val; }
62 bool isOff(T val) const { return !isOn(val); }
63
64 bool testAndClear(T val)
65 {
66 if (!isOn(val))
67 return false;
68 reset(val);
69 return true;
70 }
71
72 // TODO: what is this even supposed to do? (This function is known to exist
73 // because it is present in debug info for pead in Super Mario Run.)
74 T getMask(T v) const;
75
76 static T makeMask(int bit) { return T(1) << bit; }
77
78 void setBit(int bit) { set(makeMask(bit)); }
79 void resetBit(int bit) { reset(val: makeMask(bit)); }
80 void changeBit(int bit, bool on) { change(val: makeMask(bit), on); }
81 void toggleBit(int bit) { toggle(val: makeMask(bit)); }
82 bool isOnBit(int bit) const { return isOn(val: makeMask(bit)); }
83 bool isOffBit(int bit) const { return isOff(val: makeMask(bit)); }
84 bool testAndClearBit(int bit) { return testAndClear(val: makeMask(bit)); }
85
86 /// Popcount.
87 int countOnBit() const
88 {
89 if constexpr (sizeof(T) <= 4)
90 return BitFlagUtil::countOnBit(x: mBits);
91 else
92 return BitFlagUtil::countOnBit64(x: mBits);
93 }
94 /// Count trailing zeroes.
95 int countContinuousOffBitFromRight() const
96 {
97 if constexpr (sizeof(T) <= 4)
98 return BitFlagUtil::countContinuousOffBitFromRight(x: mBits);
99 else
100 return BitFlagUtil::countContinuousOffBitFromRight64(x: mBits);
101 }
102 int countRightOnBit(int bit) const
103 {
104 if constexpr (sizeof(T) <= 4)
105 return BitFlagUtil::countRightOnBit(x: mBits, bit);
106 else
107 return BitFlagUtil::countRightOnBit64(x: mBits, bit);
108 }
109 int findOnBitFromRight(int num) const
110 {
111 if constexpr (sizeof(T) <= 4)
112 return BitFlagUtil::findOnBitFromRight(x: mBits, num);
113 else
114 return BitFlagUtil::findOnBitFromRight64(x: mBits, num);
115 }
116
117protected:
118 T mBits;
119};
120
121using BitFlag8 = BitFlag<u8>;
122using BitFlag16 = BitFlag<u16>;
123using BitFlag32 = BitFlag<u32>;
124using BitFlag64 = BitFlag<u64>;
125
126} // namespace sead
127