1#include <prim/seadSafeString.h>
2#include <prim/seadStringUtil.h>
3
4namespace
5{
6static const char16 cEmptyStringChar16[1] = u"";
7
8} // namespace
9
10namespace sead
11{
12template <>
13const char SafeStringBase<char>::cNullChar = '\0';
14
15template <>
16const char SafeStringBase<char>::cLineBreakChar = '\n';
17
18template <>
19const SafeStringBase<char> SafeStringBase<char>::cEmptyString("");
20
21template <>
22const char16 SafeStringBase<char16>::cNullChar = 0;
23
24template <>
25const char16 SafeStringBase<char16>::cLineBreakChar = static_cast<char16>('\n');
26
27template <>
28const SafeStringBase<char16> SafeStringBase<char16>::cEmptyString(cEmptyStringChar16);
29
30template <>
31SafeStringBase<char>& SafeStringBase<char>::operator=(const SafeStringBase<char>& other) = default;
32
33template <>
34SafeStringBase<char16>&
35SafeStringBase<char16>::operator=(const SafeStringBase<char16>& other) = default;
36
37template <>
38BufferedSafeStringBase<char>&
39BufferedSafeStringBase<char>::operator=(const SafeStringBase<char>& other)
40{
41 copy(src: other);
42 return *this;
43}
44
45template <>
46BufferedSafeStringBase<char16>&
47BufferedSafeStringBase<char16>::operator=(const SafeStringBase<char16>& other)
48{
49 copy(src: other);
50 return *this;
51}
52
53template <>
54HeapSafeStringBase<char>& HeapSafeStringBase<char>::operator=(const SafeStringBase<char>& other)
55{
56 this->copy(src: other);
57 return *this;
58}
59
60template <>
61HeapSafeStringBase<char16>&
62HeapSafeStringBase<char16>::operator=(const SafeStringBase<char16>& other)
63{
64 this->copy(src: other);
65 return *this;
66}
67
68template <>
69void BufferedSafeStringBase<char>::assureTerminationImpl_() const
70{
71 auto* mutableSafeString = const_cast<BufferedSafeStringBase<char>*>(this);
72 mutableSafeString->getMutableStringTop_()[mBufferSize - 1] = cNullChar;
73}
74
75template <>
76void BufferedSafeStringBase<char16>::assureTerminationImpl_() const
77{
78 auto* mutableSafeString = const_cast<BufferedSafeStringBase<char16>*>(this);
79 mutableSafeString->getMutableStringTop_()[mBufferSize - 1] = cNullChar;
80}
81
82template <>
83s32 BufferedSafeStringBase<char>::formatImpl_(char* s, s32 n, const char* formatStr, va_list args)
84{
85 const s32 ret = StringUtil::vsnprintf(s, n, format: formatStr, args);
86 return ret < 0 ? n - 1 : ret;
87}
88
89template <>
90s32 BufferedSafeStringBase<char16>::formatImpl_(char16* s, s32 n, const char16* formatStr,
91 va_list args)
92{
93 const s32 ret = StringUtil::vsnw16printf(s, n, format: formatStr, args);
94 if (ret >= 0 && ret < n)
95 return ret;
96 s[n - 1] = WSafeString::cNullChar;
97 return n - 1;
98}
99
100template <>
101s32 BufferedSafeStringBase<char>::formatV(const char* formatStr, va_list args)
102{
103 char* mutableString = getMutableStringTop_();
104 return formatImpl_(s: mutableString, n: mBufferSize, formatStr, args);
105}
106
107template <>
108s32 BufferedSafeStringBase<char16>::formatV(const char16* formatStr, va_list args)
109{
110 char16* mutableString = getMutableStringTop_();
111 return formatImpl_(s: mutableString, n: mBufferSize, formatStr, args);
112}
113
114template <>
115s32 BufferedSafeStringBase<char>::format(const char* formatStr, ...)
116{
117 va_list args;
118 va_start(args, formatStr);
119 s32 ret = formatV(formatStr, args);
120 va_end(args);
121
122 return ret;
123}
124
125template <>
126s32 BufferedSafeStringBase<char16>::format(const char16* formatStr, ...)
127{
128 va_list args;
129 va_start(args, formatStr);
130 s32 ret = formatV(formatStr, args);
131 va_end(args);
132
133 return ret;
134}
135
136template <>
137s32 BufferedSafeStringBase<char>::appendWithFormatV(const char* format, std::va_list args)
138{
139 char* mutableString = getMutableStringTop_();
140 const s32 len = calcLength();
141 return formatImpl_(s: mutableString + len, n: mBufferSize - len, formatStr: format, args) + len;
142}
143
144template <>
145s32 BufferedSafeStringBase<char16>::appendWithFormatV(const char16* format, std::va_list args)
146{
147 char16* mutableString = getMutableStringTop_();
148 const s32 len = calcLength();
149 return formatImpl_(s: mutableString + len, n: mBufferSize - len, formatStr: format, args) + len;
150}
151
152template <>
153s32 BufferedSafeStringBase<char>::appendWithFormat(const char* format, ...)
154{
155 std::va_list args;
156 va_start(args, format);
157 const s32 ret = appendWithFormatV(format, args);
158 va_end(args);
159 return ret;
160}
161
162template <>
163s32 BufferedSafeStringBase<char16>::appendWithFormat(const char16* format, ...)
164{
165 std::va_list args;
166 va_start(args, format);
167 const s32 ret = appendWithFormatV(format, args);
168 va_end(args);
169 return ret;
170}
171
172// NON_MATCHING
173template <typename T>
174s32 replaceStringImpl_(T* dst, s32* length, s32 dst_size, const T* src, s32 src_size,
175 const SafeStringBase<T>& old_str, const SafeStringBase<T>& new_str,
176 bool* is_buffer_overflow)
177{
178 s32 ret = 0;
179 *is_buffer_overflow = false;
180 const s32 dst_max_idx = dst_size - 1;
181
182 const T* old_cstr = old_str.cstr();
183 const s32 old_str_len = old_str.calcLength();
184
185 if (old_str_len == 0)
186 {
187 if (dst == src)
188 return 0;
189
190 *is_buffer_overflow = src_size >= dst_size;
191 if (src_size >= dst_size)
192 {
193 MemUtil::copy(dest: dst, src, size: dst_max_idx);
194 dst[dst_max_idx] = SafeStringBase<T>::cNullChar;
195 if (length)
196 *length = dst_max_idx;
197 }
198 else
199 {
200 MemUtil::copy(dest: dst, src, size: src_size + 1);
201 if (length)
202 *length = src_size;
203 }
204 return 0;
205 }
206
207 const T* new_cstr = new_str.cstr();
208 const s32 new_str_len = new_str.calcLength();
209
210 // Replace in-place.
211 if (dst == src && old_str_len < new_str_len)
212 {
213 s32 dst_final_size = 0;
214 s32 src_final_size = 0;
215 // First, terminate the string and check for buffer overflow.
216 while (src_final_size < src_size)
217 {
218 const s32 cmp = MemUtil::compare(ptr1: &dst[src_final_size], ptr2: old_cstr, size: old_str_len);
219 src_final_size += cmp == 0 ? old_str_len : 1;
220 dst_final_size += cmp == 0 ? new_str_len : 1;
221 if (dst_final_size >= dst_size)
222 {
223 *is_buffer_overflow = true;
224 break;
225 }
226 }
227
228 if (*is_buffer_overflow)
229 {
230 dst[dst_max_idx] = SafeStringBase<T>::cNullChar;
231 if (length)
232 *length = dst_max_idx;
233 }
234 else
235 {
236 dst[dst_final_size] = SafeStringBase<T>::cNullChar;
237 if (length)
238 *length = dst_final_size;
239 }
240
241 s32 dst_i = dst_final_size - 1;
242 s32 src_i = src_final_size - 1;
243 while (src_i >= 0)
244 {
245 const s32 cmp = MemUtil::compare(ptr1: &dst[src_i + 1 - old_str_len], ptr2: old_cstr, size: old_str_len);
246 if (cmp == 0)
247 {
248 dst_i -= new_str_len;
249 const s32 copy_size = std::min(a: new_str_len, b: dst_size - 2 - dst_i);
250 if (copy_size > 0)
251 {
252 MemUtil::copy(dest: &dst[dst_i + 1], src: new_cstr, size: copy_size);
253 ret += 1;
254 }
255 src_i -= old_str_len;
256 }
257 else
258 {
259 if (dst_i < dst_max_idx)
260 dst[dst_i] = dst[src_i];
261 if (src_i < 1)
262 {
263 --src_i;
264 --dst_i;
265 break;
266 }
267 }
268 }
269
270 SEAD_ASSERT(dst_i == -1);
271 SEAD_ASSERT(src_i == -1);
272 }
273 // Simpler case.
274 else
275 {
276 s32 target_i = 0;
277 s32 buffer_i = 0;
278 while (target_i < src_size)
279 {
280 const s32 cmp = MemUtil::compare(ptr1: &src[target_i], ptr2: old_cstr, size: old_str_len);
281 // Not old_str, copy one character to the buffer.
282 if (cmp != 0)
283 {
284 if (buffer_i < dst_max_idx)
285 {
286 dst[buffer_i++] = src[target_i++];
287 continue;
288 }
289 }
290 // Found old_str, copy new_str to the buffer.
291 else
292 {
293 const s32 copy_size = std::min(a: new_str_len, b: dst_max_idx - buffer_i);
294 if (copy_size >= 1)
295 MemUtil::copy(dest: &dst[buffer_i], src: new_cstr, size: copy_size);
296 ret += new_str_len == 0 || copy_size > 0;
297 if (copy_size >= new_str_len)
298 {
299 buffer_i += new_str_len;
300 target_i += old_str_len;
301 continue;
302 }
303 }
304
305 // Buffer overflow.
306 *is_buffer_overflow = true;
307 dst[dst_max_idx] = SafeStringBase<T>::cNullChar;
308 if (length)
309 *length = dst_max_idx;
310 return ret;
311 }
312
313 SEAD_ASSERT(buffer_i <= dst_size);
314 SEAD_ASSERT(target_i == src_size);
315
316 dst[buffer_i] = SafeStringBase<T>::cNullChar;
317 if (length)
318 *length = buffer_i;
319 }
320
321 return ret;
322}
323
324template s32 replaceStringImpl_<char>(char* buffer, s32* length, s32 buffer_size,
325 const char* target_buf, s32 target_len,
326 const SafeStringBase<char>& old_str,
327 const SafeStringBase<char>& new_str,
328 bool* is_buffer_overflow);
329
330template s32 replaceStringImpl_<char16>(char16* buffer, s32* length, s32 buffer_size,
331 const char16* target_buf, s32 target_len,
332 const SafeStringBase<char16>& old_str,
333 const SafeStringBase<char16>& new_str,
334 bool* is_buffer_overflow);
335
336} // namespace sead
337