1#pragma once
2
3#include <cstdarg>
4
5#include <basis/seadRawPrint.h>
6#include <basis/seadTypes.h>
7
8namespace sead
9{
// Needed by SafeStringBase::token_iterator, whose accessors write into a buffered string.
template <typename T>
class BufferedSafeStringBase;

// Only referred to through pointers in this header.
class Heap;
14
15template <typename T>
16class SafeStringBase
17{
18public:
19 /// Iterates over every character of a string.
20 /// Note that this is extremely inefficient and leads to quadratic time complexity
21 /// because of the redundant calls to calcLength() in operator*.
22 class iterator
23 {
24 public:
25 explicit iterator(const SafeStringBase* string) : mString(string), mIndex(0) {}
26 iterator(const SafeStringBase* string, s32 index) : mString(string), mIndex(index) {}
27 bool operator==(const iterator& rhs) const
28 {
29 return mString == rhs.mString && mIndex == rhs.mIndex;
30 }
31 bool operator!=(const iterator& rhs) const { return !(rhs == *this); }
32 iterator& operator++() { return mIndex++, *this; }
33 iterator& operator--() { return mIndex--, *this; }
34 const char& operator*() const { return mString->at(idx: mIndex); }
35
36 const SafeStringBase* getString() const { return mString; }
37 s32 getIndex() const { return mIndex; }
38
39 protected:
40 const SafeStringBase* mString;
41 s32 mIndex;
42 };
43
44 /// Iterates over a string as if it were split by one or several delimiter characters.
45 class token_iterator : public iterator
46 {
47 public:
48 token_iterator(const SafeStringBase* string, const SafeStringBase& delimiter)
49 : iterator(string), mDelimiter(delimiter)
50 {
51 }
52
53 token_iterator(const SafeStringBase* string, s32 index, const SafeStringBase& delimiter)
54 : iterator(string, index), mDelimiter(delimiter)
55 {
56 }
57
58 bool operator==(const token_iterator& rhs) const
59 {
60 return static_cast<const iterator&>(*this) == static_cast<const iterator&>(rhs);
61 }
62 bool operator!=(const token_iterator& rhs) const { return !(rhs == *this); }
63
64 token_iterator& operator++();
65 token_iterator& operator--();
66
67 s32 get(BufferedSafeStringBase<T>* out) const;
68 inline s32 getAndForward(BufferedSafeStringBase<T>* out);
69 s32 cutOffGet(BufferedSafeStringBase<T>* out) const;
70 s32 cutOffGetAndForward(BufferedSafeStringBase<T>* out);
71
72 private:
73 const SafeStringBase mDelimiter;
74 };
75
76 SafeStringBase() : mStringTop(&cNullChar) {}
77 SafeStringBase(const T* str) : mStringTop(str)
78 {
79 SEAD_ASSERT_MSG(str != nullptr, "str must not be nullptr.");
80 }
81 SafeStringBase(const SafeStringBase& other) = default;
82
83 virtual ~SafeStringBase() = default;
84
85 virtual SafeStringBase& operator=(const SafeStringBase& other);
86
87 bool operator==(const SafeStringBase& rhs) const { return isEqual(str: rhs); }
88 bool operator!=(const SafeStringBase& rhs) const { return !(*this == rhs); }
89
90 iterator begin() const { return iterator(this, 0); }
91 iterator end() const { return iterator(this, calcLength()); }
92
93 token_iterator tokenBegin(const SafeStringBase& delimiter) const
94 {
95 return token_iterator(this, delimiter);
96 }
97
98 token_iterator tokenEnd(const SafeStringBase& delimiter) const
99 {
100 return token_iterator(this, calcLength() + 1, delimiter);
101 }
102
103 const T* cstr() const
104 {
105 assureTerminationImpl_();
106 return mStringTop;
107 }
108
109 const T* getStringTop() const { return mStringTop; }
110
111 inline const T& at(s32 idx) const;
112 inline const T& operator[](s32 idx) const { return at(idx); }
113
114 inline s32 calcLength() const;
115
116 inline SafeStringBase<T> getPart(s32 at) const;
117 inline SafeStringBase<T> getPart(const iterator& it) const;
118 inline SafeStringBase<T> getPart(const token_iterator& it) const;
119
120 inline bool include(const T& c) const;
121 inline bool include(const SafeStringBase<T>& str) const;
122
123 bool isEqual(const SafeStringBase<T>& str) const;
124 inline s32 compare(const SafeStringBase<T>& str) const { return comparen(str, n: cMaximumLength); }
125 inline s32 comparen(const SafeStringBase<T>& str, s32 n) const;
126
127 s32 findIndex(const SafeStringBase<T>& str) const;
128 s32 findIndex(const SafeStringBase<T>& str, s32 start_pos) const;
129 s32 rfindIndex(const SafeStringBase<T>& str) const;
130
131 iterator findIterator(const SafeStringBase& str) const { return {this, findIndex(str)}; }
132 iterator rfindIterator(const SafeStringBase& str) const { return {this, rfindIndex(str)}; }
133
134 bool isEmpty() const;
135 bool startsWith(const SafeStringBase<T>& prefix) const;
136 bool endsWith(const SafeStringBase<T>& suffix) const;
137
138 static const T cNullChar;
139 static const T cLineBreakChar;
140 static const SafeStringBase cEmptyString;
141 static const s32 cMaximumLength = 0x80000;
142
143protected:
144 virtual void assureTerminationImpl_() const {}
145 const T& unsafeAt_(s32 idx) const { return mStringTop[idx]; }
146
147 const T* mStringTop;
148};
149
150template <>
151const SafeStringBase<char> SafeStringBase<char>::cEmptyString;
152
153template <typename T>
154s32 replaceStringImpl_(T* dst, s32* length, s32 dst_size, const T* src, s32 src_size,
155 const SafeStringBase<T>& old_str, const SafeStringBase<T>& new_str,
156 bool* is_buffer_overflow);
157
158template <typename T>
159class BufferedSafeStringBase : public SafeStringBase<T>
160{
161public:
162 __attribute__((always_inline)) BufferedSafeStringBase(T* buffer, s32 size)
163 : SafeStringBase<T>(buffer)
164 {
165 mBufferSize = size;
166 if (size <= 0)
167 {
168 SEAD_ASSERT_MSG(false, "Invalied buffer size(%d).\n", this->getBufferSize());
169 this->mStringTop = nullptr;
170 this->mBufferSize = 0;
171 }
172 else
173 {
174 this->assureTerminationImpl_();
175 }
176 }
177
178 BufferedSafeStringBase(const BufferedSafeStringBase&) = default;
179 ~BufferedSafeStringBase() override = default;
180
181 BufferedSafeStringBase<T>& operator=(const SafeStringBase<T>& other) override;
182
183 const T& operator[](s32 idx) const;
184
185 T* getBuffer()
186 {
187 assureTerminationImpl_();
188 return getMutableStringTop_();
189 }
190 s32 getBufferSize() const { return mBufferSize; }
191
192 /// Copy up to copyLength characters to the beginning of the string, then writes NUL.
193 /// @param src Source string
194 /// @param copyLength Number of characters from src to copy (must not cause a buffer overflow)
195 inline s32 copy(const SafeStringBase<T>& src, s32 copyLength = -1);
196 /// Copy up to copyLength characters to the specified position, then writes NUL if the copy
197 /// makes this string longer.
198 /// @param at Start position (-1 for end of string)
199 /// @param src Source string
200 /// @param copyLength Number of characters from src to copy (must not cause a buffer overflow)
201 inline s32 copyAt(s32 at, const SafeStringBase<T>& src, s32 copyLength = -1);
202 /// Copy up to copyLength characters to the beginning of the string, then writes NUL.
203 /// Silently truncates the source string if the buffer is too small.
204 /// @param src Source string
205 /// @param copyLength Number of characters from src to copy
206 inline s32 cutOffCopy(const SafeStringBase<T>& src, s32 copyLength = -1);
207 /// Copy up to copyLength characters to the specified position, then writes NUL if the copy
208 /// makes this string longer.
209 /// Silently truncates the source string if the buffer is too small.
210 /// @param at Start position (-1 for end of string)
211 /// @param src Source string
212 /// @param copyLength Number of characters from src to copy
213 inline s32 cutOffCopyAt(s32 at, const SafeStringBase<T>& src, s32 copyLength = -1);
214 /// Copy up to copyLength characters to the specified position, then *always* writes NUL.
215 /// @param at Start position (-1 for end of string)
216 /// @param src Source string
217 /// @param copyLength Number of characters from src to copy (must not cause a buffer overflow)
218 inline s32 copyAtWithTerminate(s32 at, const SafeStringBase<T>& src, s32 copyLength = -1);
219
220 s32 format(const T* format, ...);
221 s32 formatV(const T* format, std::va_list args);
222 s32 appendWithFormat(const T* formatStr, ...);
223 s32 appendWithFormatV(const T* formatStr, std::va_list args);
224
225 /// Append append_length characters from str.
226 s32 append(const SafeStringBase<T>& str, s32 append_length = -1);
227 /// Append a character.
228 s32 append(T c);
229 /// Append a character n times.
230 s32 append(T c, s32 n);
231
232 // Implementation note: These member functions appear to be inlined in most titles.
233 // However, StringBuilderBase<T> conveniently duplicates the APIs and implementations of
234 // SafeStringBase<T> and BufferedSafeString<T>: some assertion messages are even identical,
235 // and the good news is that most StringBuilderBase<T> functions are not inlined!
236
237 /// Append prepend_length characters from str.
238 /// @return the new length
239 s32 prepend(const SafeStringBase<T>& str, s32 prepend_length = -1);
240
241 /// Remove num characters from the end of the string.
242 /// @return the number of characters that were removed
243 s32 chop(s32 num);
244 /// Remove the last character if it is equal to c.
245 /// @return the number of characters that were removed
246 s32 chopMatchedChar(T c);
247 /// Remove the last character if it is equal to any of the specified characters.
248 /// @param characters List of characters to remove
249 /// @return the number of characters that were removed
250 s32 chopMatchedChar(const T* characters);
251 /// Remove the last character if it is unprintable.
252 /// @warning The behavior of this function is not standard: a character is considered
253 /// unprintable if it is <= 0x20 or == 0x7F. In particular, the space character is unprintable.
254 /// @return the number of characters that were removed
255 s32 chopUnprintableAsciiChar();
256
257 /// Remove trailing characters that are in the specified list.
258 /// @param characters List of characters to remove
259 /// @return the number of characters that were removed
260 s32 rstrip(const T* characters);
261 /// Remove trailing characters that are unprintable.
262 /// @warning The behavior of this function is not standard: a character is considered
263 /// unprintable if it is <= 0x20 or == 0x7F. In particular, the space character is unprintable.
264 /// @return the number of characters that were removed
265 s32 rstripUnprintableAsciiChars();
266
267 /// Trim a string to only keep trimLength characters.
268 /// @return the new length
269 inline s32 trim(s32 trim_length);
270 /// Remove the specified suffix from the string if it ends with the suffix.
271 /// @return the new length
272 inline s32 trimMatchedString(const SafeStringBase<T>& suffix);
273 /// Remove the specified suffix from the string if it ends with the suffix.
274 /// Alias of trimMatchedString.
275 /// @return the new length
276 inline s32 removeSuffix(const SafeStringBase<T>& suffix) { return trimMatchedString(suffix); }
277
278 /// @return the number of characters that were replaced
279 inline s32 replaceChar(T old_char, T new_char);
280 /// @return the number of characters that were replaced
281 inline s32 replaceCharList(const SafeStringBase<T>& old_chars,
282 const SafeStringBase<T>& new_chars);
283 /// Set the contents of this string to target_str, after replacing occurrences of old_str in
284 /// target_str with new_str.
285 /// @return the number of replaced occurrences
286 inline s32 setReplaceString(const SafeStringBase<T>& target_str,
287 const SafeStringBase<T>& old_str, const SafeStringBase<T>& new_str);
288 /// Replace occurrences of old_str in this string with new_str.
289 /// @return the number of replaced occurrences
290 inline s32 replaceString(const SafeStringBase<T>& old_str, const SafeStringBase<T>& new_str);
291
292 s32 convertFromMultiByteString(const SafeStringBase<char>& str, s32 str_length);
293 s32 convertFromWideCharString(const SafeStringBase<char16>& str, s32 str_length);
294
295 inline void clear() { getMutableStringTop_()[0] = this->cNullChar; }
296
297protected:
298 void assureTerminationImpl_() const override;
299
300 T* getMutableStringTop_() { return const_cast<T*>(this->mStringTop); }
301
302 static s32 formatImpl_(T* dst, s32 dst_size, const T* format, std::va_list arg);
303
304 template <typename OtherType>
305 s32 convertFromOtherType_(const SafeStringBase<OtherType>& src, s32 src_size);
306
307 s32 mBufferSize;
308};
309
310template <typename T, s32 L>
311class FixedSafeStringBase : public BufferedSafeStringBase<T>
312{
313public:
314 FixedSafeStringBase() : BufferedSafeStringBase<T>(mBuffer, L) { this->clear(); }
315
316 FixedSafeStringBase(const SafeStringBase<T>& str) : BufferedSafeStringBase<T>(mBuffer, L)
317 {
318 this->copy(str);
319 }
320
321 FixedSafeStringBase(const FixedSafeStringBase& str) : BufferedSafeStringBase<T>(mBuffer, L)
322 {
323 this->copy(str);
324 }
325
326 ~FixedSafeStringBase() override = default;
327
328 FixedSafeStringBase& operator=(const FixedSafeStringBase& other)
329 {
330 this->copy(other);
331 return *this;
332 }
333
334 FixedSafeStringBase& operator=(const SafeStringBase<T>& other) override
335 {
336 this->copy(other);
337 return *this;
338 }
339
340 T mBuffer[L];
341};
342
343typedef SafeStringBase<char> SafeString;
344typedef SafeStringBase<char16> WSafeString;
345typedef BufferedSafeStringBase<char> BufferedSafeString;
346typedef BufferedSafeStringBase<char16> WBufferedSafeString;
347
348template <>
349s32 BufferedSafeStringBase<char>::format(const char* formatStr, ...);
350template <>
351s32 BufferedSafeStringBase<char16>::format(const char16* formatStr, ...);
352template <>
353s32 BufferedSafeStringBase<char>::formatV(const char* formatStr, va_list args);
354template <>
355s32 BufferedSafeStringBase<char16>::formatV(const char16* formatStr, va_list args);
356template <>
357s32 BufferedSafeStringBase<char>::appendWithFormat(const char* formatStr, ...);
358template <>
359s32 BufferedSafeStringBase<char16>::appendWithFormat(const char16* formatStr, ...);
360template <>
361s32 BufferedSafeStringBase<char>::appendWithFormatV(const char* formatStr, va_list args);
362template <>
363s32 BufferedSafeStringBase<char16>::appendWithFormatV(const char16* formatStr, va_list args);
364
365template <s32 L>
366class FixedSafeString : public FixedSafeStringBase<char, L>
367{
368public:
369 FixedSafeString() : FixedSafeStringBase<char, L>() {}
370 FixedSafeString(const SafeString& str) : FixedSafeStringBase<char, L>(str) {}
371 FixedSafeString(const FixedSafeString& other)
372 : FixedSafeString(static_cast<const SafeString&>(other))
373 {
374 }
375
376 FixedSafeString& operator=(const FixedSafeString& other)
377 {
378 this->copy(other);
379 return *this;
380 }
381
382 FixedSafeString<L>& operator=(const SafeStringBase<char>& other) override
383 {
384 this->copy(other);
385 return *this;
386 }
387};
388
389template <s32 L>
390class WFixedSafeString : public FixedSafeStringBase<char16, L>
391{
392public:
393 WFixedSafeString() : FixedSafeStringBase<char16, L>() {}
394
395 WFixedSafeString(const WSafeString& str) : FixedSafeStringBase<char16, L>(str) {}
396};
397
398template <s32 L>
399class FormatFixedSafeString : public FixedSafeString<L>
400{
401public:
402 FormatFixedSafeString() : FormatFixedSafeString("") {}
403
404#ifdef __GNUC__
405 [[gnu::format(printf, 2, 3)]]
406#endif
407 explicit FormatFixedSafeString(const char* format, ...)
408 : FixedSafeString<L>()
409 {
410 std::va_list args;
411 va_start(args, format);
412 this->formatV(format, args);
413 va_end(args);
414 }
415 ~FormatFixedSafeString() override = default;
416};
417
418template <s32 L>
419class WFormatFixedSafeString : public WFixedSafeString<L>
420{
421public:
422 explicit WFormatFixedSafeString(const char16* format, ...)
423 {
424 std::va_list args;
425 va_start(args, format);
426 this->formatV(format, args);
427 va_end(args);
428 }
429 ~WFormatFixedSafeString() override = default;
430};
431
432template <typename T>
433class HeapSafeStringBase : public BufferedSafeStringBase<T>
434{
435public:
436 HeapSafeStringBase(Heap* heap, const SafeStringBase<T>& string, s32 alignment = sizeof(void*))
437 : BufferedSafeStringBase<T>(new (heap, alignment) T[string.calcLength() + 1](),
438 string.calcLength() + 1)
439 {
440 this->copy(string);
441 }
442
443 HeapSafeStringBase(const HeapSafeStringBase&) = delete;
444 HeapSafeStringBase& operator=(const HeapSafeStringBase&) = delete;
445
446 HeapSafeStringBase(HeapSafeStringBase&& other) noexcept
447 {
448 this->mStringTop = other.mStringTop;
449 other.mStringTop = nullptr;
450 }
451 HeapSafeStringBase& operator=(HeapSafeStringBase&& other) noexcept
452 {
453 this->mStringTop = other.mStringTop;
454 other.mStringTop = nullptr;
455 return *this;
456 }
457
458 ~HeapSafeStringBase() override
459 {
460 if (this->mStringTop)
461 delete[] this->mStringTop;
462 }
463
464 HeapSafeStringBase<T>& operator=(const SafeStringBase<T>& other) override;
465};
466
467using HeapSafeString = HeapSafeStringBase<char>;
468using WHeapSafeString = HeapSafeStringBase<char16>;
469
470inline bool operator<(const SafeString& lhs, const SafeString& rhs)
471{
472 return lhs.compare(str: rhs) < 0;
473}
474
475inline bool operator>(const SafeString& lhs, const SafeString& rhs)
476{
477 return lhs.compare(str: rhs) > 0;
478}
479
480inline bool operator<=(const SafeString& lhs, const SafeString& rhs)
481{
482 return lhs.compare(str: rhs) <= 0;
483}
484
485inline bool operator>=(const SafeString& lhs, const SafeString& rhs)
486{
487 return lhs.compare(str: rhs) >= 0;
488}
489
490inline namespace literals
491{
492inline namespace str
493{
494inline SafeString operator""_str(const char* str, std::size_t /*len*/)
495{
496 return str;
497}
498
499inline WSafeString operator""_str(const char16* str, std::size_t /*len*/)
500{
501 return str;
502}
503
504} // namespace str
505} // namespace literals
506
507} // namespace sead
508
509#define SEAD_PRIM_SAFE_STRING_H_
510#include <prim/seadSafeString.hpp>
511#undef SEAD_PRIM_SAFE_STRING_H_
512