vqsort-inl.h
1// Copyright 2021 Google LLC
2// SPDX-License-Identifier: Apache-2.0
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16// Normal include guard for target-independent parts
17#ifndef HIGHWAY_HWY_CONTRIB_SORT_VQSORT_INL_H_
18#define HIGHWAY_HWY_CONTRIB_SORT_VQSORT_INL_H_
19
20// unconditional #include so we can use if(VQSORT_PRINT), which unlike #if does
21// not interfere with code-folding.
22#include <stdio.h>
23#include <time.h> // clock
24
25// IWYU pragma: begin_exports
26#include "hwy/base.h"
27#include "hwy/contrib/sort/order.h" // SortAscending
28// IWYU pragma: end_exports
29
30#include "hwy/cache_control.h" // Prefetch
31#include "hwy/print.h" // unconditional, see above.
32
33// If 1, VQSortStatic can be called without including vqsort.h, and we avoid
34// any DLLEXPORT. This simplifies integration into other build systems, but
35// decreases the security of random seeds.
36#ifndef VQSORT_ONLY_STATIC
37#define VQSORT_ONLY_STATIC 0
38#endif
39
40// Verbosity: 0 for none, 1 for brief per-sort, 2+ for more details.
41#ifndef VQSORT_PRINT
42#define VQSORT_PRINT 0
43#endif
44
45#if !VQSORT_ONLY_STATIC
46#include "hwy/contrib/sort/vqsort.h" // Fill16BytesSecure
47#endif
48
49namespace hwy {
50namespace detail {
51
52HWY_INLINE void Fill16BytesStatic(void* bytes) {
53#if !VQSORT_ONLY_STATIC
54 if (Fill16BytesSecure(bytes)) return;
55#endif
56
57 uint64_t* words = reinterpret_cast<uint64_t*>(bytes);
58
59 // Static-only, or Fill16BytesSecure failed. Get some entropy from the
60 // stack/code location, and the clock() timer.
61 uint64_t** seed_stack = &words;
62 void (*seed_code)(void*) = &Fill16BytesStatic;
63 const uintptr_t bits_stack = reinterpret_cast<uintptr_t>(seed_stack);
64 const uintptr_t bits_code = reinterpret_cast<uintptr_t>(seed_code);
65 const uint64_t bits_time = static_cast<uint64_t>(clock());
66 words[0] = bits_stack ^ bits_time ^ 0xFEDCBA98; // "Nothing up my sleeve"
67 words[1] = bits_code ^ bits_time ^ 0x01234567; // constants.
68}
69
70HWY_INLINE uint64_t* GetGeneratorStateStatic() {
71 thread_local uint64_t state[3] = {0};
72 // This is a counter; zero indicates not yet initialized.
73 if (HWY_UNLIKELY(state[2] == 0)) {
74 Fill16BytesStatic(state);
75 state[2] = 1;
76 }
77 return state;
78}
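// Editorial sketch (not part of the original file): how this thread-local
// state is typically consumed. state[0] and state[1] are the generator words
// and state[2] is the counter that RandomBits (defined in the per-target
// section below) advances on every draw. A VQSortStatic-style entry point
// would fetch the pointer once per sort and pass it to pivot sampling:
//   uint64_t* state = GetGeneratorStateStatic();  // name per the line above
//   // state[2] != 0 afterwards, so later calls skip re-seeding.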
79
80} // namespace detail
81} // namespace hwy
82
83#endif // HIGHWAY_HWY_CONTRIB_SORT_VQSORT_INL_H_
84
85// Per-target
86#if defined(HIGHWAY_HWY_CONTRIB_SORT_VQSORT_TOGGLE) == \
87 defined(HWY_TARGET_TOGGLE)
88#ifdef HIGHWAY_HWY_CONTRIB_SORT_VQSORT_TOGGLE
89#undef HIGHWAY_HWY_CONTRIB_SORT_VQSORT_TOGGLE
90#else
91#define HIGHWAY_HWY_CONTRIB_SORT_VQSORT_TOGGLE
92#endif
93
94#if VQSORT_PRINT
95#include "hwy/print-inl.h"
96#endif
97
103// Placeholder for internal instrumentation. Do not remove.
104#include "hwy/highway.h"
105
106HWY_BEFORE_NAMESPACE();
107namespace hwy {
108namespace HWY_NAMESPACE {
109namespace detail {
110
111using Constants = hwy::SortConstants;
112
113// Wrapper avoids #if in user code (interferes with code folding)
114template <class D>
115HWY_INLINE void MaybePrintVector(D d, const char* label, Vec<D> v,
116 size_t start = 0, size_t max_lanes = 16) {
117#if VQSORT_PRINT >= 2 // Print is only defined #if VQSORT_PRINT
118 Print(d, label, v, start, max_lanes);
119#else
120 (void)d;
121 (void)label;
122 (void)v;
123 (void)start;
124 (void)max_lanes;
125#endif
126}
127
128// ------------------------------ HeapSort
129
130template <class Traits, typename T>
131void SiftDown(Traits st, T* HWY_RESTRICT lanes, const size_t num_lanes,
132 size_t start) {
133 constexpr size_t N1 = st.LanesPerKey();
134 const FixedTag<T, N1> d;
135
136 while (start < num_lanes) {
137 const size_t left = 2 * start + N1;
138 const size_t right = 2 * start + 2 * N1;
139 if (left >= num_lanes) break;
140 size_t idx_larger = start;
141 const auto key_j = st.SetKey(d, lanes + start);
142 if (AllTrue(d, st.Compare(d, key_j, st.SetKey(d, lanes + left)))) {
143 idx_larger = left;
144 }
145 if (right < num_lanes &&
146 AllTrue(d, st.Compare(d, st.SetKey(d, lanes + idx_larger),
147 st.SetKey(d, lanes + right)))) {
148 idx_larger = right;
149 }
150 if (idx_larger == start) break;
151 st.Swap(lanes + start, lanes + idx_larger);
152 start = idx_larger;
153 }
154}
155
156// Heapsort: O(1) space, O(N*logN) worst-case comparisons.
157// Based on LLVM sanitizer_common.h, licensed under Apache-2.0.
158template <class Traits, typename T>
159void HeapSort(Traits st, T* HWY_RESTRICT lanes, const size_t num_lanes) {
160 constexpr size_t N1 = st.LanesPerKey();
161
162 HWY_ASSERT(num_lanes >= 2 * N1);
163
164 // Build heap.
165 for (size_t i = ((num_lanes - N1) / N1 / 2) * N1; i != (~N1 + 1); i -= N1) {
166 SiftDown(st, lanes, num_lanes, i);
167 }
168
169 for (size_t i = num_lanes - N1; i != 0; i -= N1) {
170 // Swap root with last
171 st.Swap(lanes + 0, lanes + i);
172
173 // Sift down the new root.
174 SiftDown(st, lanes, i, 0);
175 }
176}
177
178template <class Traits, typename T>
179void HeapSelect(Traits st, T* HWY_RESTRICT lanes, const size_t num_lanes,
180 const size_t select) {
181 constexpr size_t N1 = st.LanesPerKey();
182 const size_t k = select + 1;
183
184 HWY_ASSERT(k >= 2 * N1 && num_lanes >= 2 * N1);
185
186 const FixedTag<T, N1> d;
187
188 // Build heap.
189 for (size_t i = ((k - N1) / N1 / 2) * N1; i != (~N1 + 1); i -= N1) {
190 SiftDown(st, lanes, k, i);
191 }
192
193 for (size_t i = k; i <= num_lanes - N1; i += N1) {
194 if (AllTrue(d, st.Compare(d, st.SetKey(d, lanes + i),
195 st.SetKey(d, lanes + 0)))) {
196 // Swap root with last
197 st.Swap(lanes + 0, lanes + i);
198
199 // Sift down the new root.
200 SiftDown(st, lanes, k, 0);
201 }
202 }
203
204 st.Swap(lanes + 0, lanes + k - 1);
205}
206
207template <class Traits, typename T>
208void HeapPartialSort(Traits st, T* HWY_RESTRICT lanes, const size_t num_lanes,
209 const size_t select) {
210 HeapSelect(st, lanes, num_lanes, select);
211 HeapSort(st, lanes, select);
212}
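// Editorial sketch (not part of the original sources): a scalar analogue of
// SiftDown/HeapSort above for N1 == 1 and plain ascending uint32_t keys. It
// is meant only to illustrate the max-heap structure; st.Compare generalizes
// the `<` used here to other sort orders and to 128-bit keys.
#if 0  // illustration only, never compiled
#include <stddef.h>
#include <stdint.h>

static void ScalarSiftDown(uint32_t* lanes, size_t num, size_t start) {
  for (;;) {
    const size_t left = 2 * start + 1;
    const size_t right = 2 * start + 2;
    if (left >= num) break;
    size_t largest = start;  // index of the largest of {start, left, right}
    if (lanes[largest] < lanes[left]) largest = left;
    if (right < num && lanes[largest] < lanes[right]) largest = right;
    if (largest == start) break;  // heap property already holds
    const uint32_t t = lanes[start];
    lanes[start] = lanes[largest];
    lanes[largest] = t;
    start = largest;
  }
}

static void ScalarHeapSort(uint32_t* lanes, size_t num) {
  if (num < 2) return;
  // Build a max-heap, then repeatedly move the root behind the shrinking heap.
  for (size_t i = num / 2; i-- > 0;) ScalarSiftDown(lanes, num, i);
  for (size_t i = num - 1; i != 0; --i) {
    const uint32_t t = lanes[0];
    lanes[0] = lanes[i];
    lanes[i] = t;
    ScalarSiftDown(lanes, i, 0);
  }
}
#endif  // 0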
213
214#if VQSORT_ENABLED || HWY_IDE
215
216// ------------------------------ BaseCase
217
218// Special cases where `num_lanes` is in the specified range (inclusive).
219template <class Traits, typename T>
220HWY_INLINE void Sort2To2(Traits st, T* HWY_RESTRICT keys, size_t num_lanes,
221 T* HWY_RESTRICT /* buf */) {
222 constexpr size_t kLPK = st.LanesPerKey();
223 const size_t num_keys = num_lanes / kLPK;
224 HWY_DASSERT(num_keys == 2);
225 HWY_ASSUME(num_keys == 2);
226
227 // One key per vector, required to avoid reading past the end of `keys`.
228 const CappedTag<T, kLPK> d;
229 using V = Vec<decltype(d)>;
230
231 V v0 = LoadU(d, keys + 0x0 * kLPK);
232 V v1 = LoadU(d, keys + 0x1 * kLPK);
233
234 Sort2(d, st, v0, v1);
235
236 StoreU(v0, d, keys + 0x0 * kLPK);
237 StoreU(v1, d, keys + 0x1 * kLPK);
238}
239
240template <class Traits, typename T>
241HWY_INLINE void Sort3To4(Traits st, T* HWY_RESTRICT keys, size_t num_lanes,
242 T* HWY_RESTRICT buf) {
243 constexpr size_t kLPK = st.LanesPerKey();
244 const size_t num_keys = num_lanes / kLPK;
245 HWY_DASSERT(3 <= num_keys && num_keys <= 4);
246 HWY_ASSUME(num_keys >= 3);
247 HWY_ASSUME(num_keys <= 4); // reduces branches
248
249 // One key per vector, required to avoid reading past the end of `keys`.
250 const CappedTag<T, kLPK> d;
251 using V = Vec<decltype(d)>;
252
253 // If num_keys == 3, initialize padding for the last sorting network element
254 // so that it does not influence the other elements.
255 Store(st.LastValue(d), d, buf);
256
257 // Points to a valid key, or padding. This avoids special-casing
258 // HWY_MEM_OPS_MIGHT_FAULT because there is only a single key per vector.
259 T* in_out3 = num_keys == 3 ? buf : keys + 0x3 * kLPK;
260
261 V v0 = LoadU(d, keys + 0x0 * kLPK);
262 V v1 = LoadU(d, keys + 0x1 * kLPK);
263 V v2 = LoadU(d, keys + 0x2 * kLPK);
264 V v3 = LoadU(d, in_out3);
265
266 Sort4(d, st, v0, v1, v2, v3);
267
268 StoreU(v0, d, keys + 0x0 * kLPK);
269 StoreU(v1, d, keys + 0x1 * kLPK);
270 StoreU(v2, d, keys + 0x2 * kLPK);
271 StoreU(v3, d, in_out3);
272}
273
274#if HWY_MEM_OPS_MIGHT_FAULT
275
276template <size_t kRows, size_t kLanesPerRow, class D, class Traits,
277 typename T = TFromD<D>>
278HWY_INLINE void CopyHalfToPaddedBuf(D d, Traits st, T* HWY_RESTRICT keys,
279 size_t num_lanes, T* HWY_RESTRICT buf) {
280 constexpr size_t kMinLanes = kRows / 2 * kLanesPerRow;
281 // Must cap for correctness: we will load up to the last valid lane, so
282 // Lanes(dmax) must not exceed `num_lanes` (known to be at least kMinLanes).
283 const CappedTag<T, kMinLanes> dmax;
284 const size_t Nmax = Lanes(dmax);
285 HWY_DASSERT(Nmax < num_lanes);
286 HWY_ASSUME(Nmax <= kMinLanes);
287
288 // Fill with padding - last in sort order, not copied to keys.
289 const Vec<decltype(dmax)> kPadding = st.LastValue(dmax);
290
291 // Rounding down allows aligned stores, which are typically faster.
292 size_t i = num_lanes & ~(Nmax - 1);
293 HWY_ASSUME(i != 0); // because Nmax <= num_lanes; avoids branch
294 do {
295 Store(kPadding, dmax, buf + i);
296 i += Nmax;
297 // Initialize enough for the last vector even if Nmax > kLanesPerRow.
298 } while (i < (kRows - 1) * kLanesPerRow + Lanes(d));
299
300 // Ensure buf contains all we will read, and perhaps more before.
301 ptrdiff_t end = static_cast<ptrdiff_t>(num_lanes);
302 do {
303 end -= static_cast<ptrdiff_t>(Nmax);
304 StoreU(LoadU(dmax, keys + end), dmax, buf + end);
305 } while (end > static_cast<ptrdiff_t>(kRows / 2 * kLanesPerRow));
306}
307
308#endif // HWY_MEM_OPS_MIGHT_FAULT
309
310template <size_t kKeysPerRow, class Traits, typename T>
311HWY_NOINLINE void Sort8Rows(Traits st, T* HWY_RESTRICT keys, size_t num_lanes,
312 T* HWY_RESTRICT buf) {
313 // kKeysPerRow <= 4 because 8 64-bit keys implies 512-bit vectors, which
314 // are likely slower than 16x4, so 8x4 is the largest we handle here.
315 static_assert(kKeysPerRow <= 4, "");
316
317 constexpr size_t kLPK = st.LanesPerKey();
318
319 // We reshape the 1D keys into kRows x kKeysPerRow.
320 constexpr size_t kRows = 8;
321 constexpr size_t kLanesPerRow = kKeysPerRow * kLPK;
322 constexpr size_t kMinLanes = kRows / 2 * kLanesPerRow;
323 HWY_DASSERT(kMinLanes < num_lanes && num_lanes <= kRows * kLanesPerRow);
324
325 const CappedTag<T, kLanesPerRow> d;
326 using V = Vec<decltype(d)>;
327 V v4, v5, v6, v7;
328
329 // At least half the kRows are valid, otherwise a different function would
330 // have been called to handle this num_lanes.
331 V v0 = LoadU(d, keys + 0x0 * kLanesPerRow);
332 V v1 = LoadU(d, keys + 0x1 * kLanesPerRow);
333 V v2 = LoadU(d, keys + 0x2 * kLanesPerRow);
334 V v3 = LoadU(d, keys + 0x3 * kLanesPerRow);
335#if HWY_MEM_OPS_MIGHT_FAULT
336 CopyHalfToPaddedBuf<kRows, kLanesPerRow>(d, st, keys, num_lanes, buf);
337 v4 = LoadU(d, buf + 0x4 * kLanesPerRow);
338 v5 = LoadU(d, buf + 0x5 * kLanesPerRow);
339 v6 = LoadU(d, buf + 0x6 * kLanesPerRow);
340 v7 = LoadU(d, buf + 0x7 * kLanesPerRow);
341#endif // HWY_MEM_OPS_MIGHT_FAULT
342#if !HWY_MEM_OPS_MIGHT_FAULT || HWY_IDE
343 (void)buf;
344 const V vnum_lanes = Set(d, ConvertScalarTo<T>(num_lanes));
345 // First offset where not all vectors are guaranteed valid.
346 const V kIota = Iota(d, static_cast<T>(kMinLanes));
347 const V k1 = Set(d, static_cast<T>(kLanesPerRow));
348 const V k2 = Add(k1, k1);
349
350 using M = Mask<decltype(d)>;
351 const M m4 = Gt(vnum_lanes, kIota);
352 const M m5 = Gt(vnum_lanes, Add(kIota, k1));
353 const M m6 = Gt(vnum_lanes, Add(kIota, k2));
354 const M m7 = Gt(vnum_lanes, Add(kIota, Add(k2, k1)));
355
356 const V kPadding = st.LastValue(d); // Not copied to keys.
357 v4 = MaskedLoadOr(kPadding, m4, d, keys + 0x4 * kLanesPerRow);
358 v5 = MaskedLoadOr(kPadding, m5, d, keys + 0x5 * kLanesPerRow);
359 v6 = MaskedLoadOr(kPadding, m6, d, keys + 0x6 * kLanesPerRow);
360 v7 = MaskedLoadOr(kPadding, m7, d, keys + 0x7 * kLanesPerRow);
361#endif // !HWY_MEM_OPS_MIGHT_FAULT
362
363 Sort8(d, st, v0, v1, v2, v3, v4, v5, v6, v7);
364
365 // Merge8x2 is a no-op if kKeysPerRow < 2 etc.
366 Merge8x2<kKeysPerRow>(d, st, v0, v1, v2, v3, v4, v5, v6, v7);
367 Merge8x4<kKeysPerRow>(d, st, v0, v1, v2, v3, v4, v5, v6, v7);
368
369 StoreU(v0, d, keys + 0x0 * kLanesPerRow);
370 StoreU(v1, d, keys + 0x1 * kLanesPerRow);
371 StoreU(v2, d, keys + 0x2 * kLanesPerRow);
372 StoreU(v3, d, keys + 0x3 * kLanesPerRow);
373
374#if HWY_MEM_OPS_MIGHT_FAULT
375 // Store remaining vectors into buf and safely copy them into keys.
376 StoreU(v4, d, buf + 0x4 * kLanesPerRow);
377 StoreU(v5, d, buf + 0x5 * kLanesPerRow);
378 StoreU(v6, d, buf + 0x6 * kLanesPerRow);
379 StoreU(v7, d, buf + 0x7 * kLanesPerRow);
380
381 const ScalableTag<T> dmax;
382 const size_t Nmax = Lanes(dmax);
383
384 // The first half of vectors have already been stored unconditionally into
385 // `keys`, so we do not copy them.
386 size_t i = kMinLanes;
387 HWY_UNROLL(1)
388 for (; i + Nmax <= num_lanes; i += Nmax) {
389 StoreU(LoadU(dmax, buf + i), dmax, keys + i);
390 }
391
392 // Last iteration: copy partial vector
393 const size_t remaining = num_lanes - i;
394 HWY_ASSUME(remaining < 256); // helps FirstN
395 SafeCopyN(remaining, dmax, buf + i, keys + i);
396#endif // HWY_MEM_OPS_MIGHT_FAULT
397#if !HWY_MEM_OPS_MIGHT_FAULT || HWY_IDE
398 BlendedStore(v4, m4, d, keys + 0x4 * kLanesPerRow);
399 BlendedStore(v5, m5, d, keys + 0x5 * kLanesPerRow);
400 BlendedStore(v6, m6, d, keys + 0x6 * kLanesPerRow);
401 BlendedStore(v7, m7, d, keys + 0x7 * kLanesPerRow);
402#endif // !HWY_MEM_OPS_MIGHT_FAULT
403}
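// Editorial note (worked example, not in the original file): with 1-lane keys
// and kKeysPerRow = 2 (so kLanesPerRow = 2 and kMinLanes = 8), num_lanes = 13
// reshapes into rows starting at lanes 0, 2, ..., 14. In the masked-load path
// above, m4 and m5 (rows at lanes 8 and 10) are all-true, m6 (lanes 12..13)
// is true only for lane 12, and m7 is all-false; the invalid lanes read
// st.LastValue padding, sort to the end, and are excluded again when writing
// back via BlendedStore (or via the buf copy in the might-fault path).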
404
405template <size_t kKeysPerRow, class Traits, typename T>
406HWY_NOINLINE void Sort16Rows(Traits st, T* HWY_RESTRICT keys, size_t num_lanes,
407 T* HWY_RESTRICT buf) {
408 static_assert(kKeysPerRow <= SortConstants::kMaxCols, "");
409
410 constexpr size_t kLPK = st.LanesPerKey();
411
412 // We reshape the 1D keys into kRows x kKeysPerRow.
413 constexpr size_t kRows = 16;
414 constexpr size_t kLanesPerRow = kKeysPerRow * kLPK;
415 constexpr size_t kMinLanes = kRows / 2 * kLanesPerRow;
416 HWY_DASSERT(kMinLanes < num_lanes && num_lanes <= kRows * kLanesPerRow);
417
418 const CappedTag<T, kLanesPerRow> d;
419 using V = Vec<decltype(d)>;
420 V v8, v9, va, vb, vc, vd, ve, vf;
421
422 // At least half the kRows are valid, otherwise a different function would
423 // have been called to handle this num_lanes.
424 V v0 = LoadU(d, keys + 0x0 * kLanesPerRow);
425 V v1 = LoadU(d, keys + 0x1 * kLanesPerRow);
426 V v2 = LoadU(d, keys + 0x2 * kLanesPerRow);
427 V v3 = LoadU(d, keys + 0x3 * kLanesPerRow);
428 V v4 = LoadU(d, keys + 0x4 * kLanesPerRow);
429 V v5 = LoadU(d, keys + 0x5 * kLanesPerRow);
430 V v6 = LoadU(d, keys + 0x6 * kLanesPerRow);
431 V v7 = LoadU(d, keys + 0x7 * kLanesPerRow);
432#if HWY_MEM_OPS_MIGHT_FAULT
433 CopyHalfToPaddedBuf<kRows, kLanesPerRow>(d, st, keys, num_lanes, buf);
434 v8 = LoadU(d, buf + 0x8 * kLanesPerRow);
435 v9 = LoadU(d, buf + 0x9 * kLanesPerRow);
436 va = LoadU(d, buf + 0xa * kLanesPerRow);
437 vb = LoadU(d, buf + 0xb * kLanesPerRow);
438 vc = LoadU(d, buf + 0xc * kLanesPerRow);
439 vd = LoadU(d, buf + 0xd * kLanesPerRow);
440 ve = LoadU(d, buf + 0xe * kLanesPerRow);
441 vf = LoadU(d, buf + 0xf * kLanesPerRow);
442#endif // HWY_MEM_OPS_MIGHT_FAULT
443#if !HWY_MEM_OPS_MIGHT_FAULT || HWY_IDE
444 (void)buf;
445 const V vnum_lanes = Set(d, ConvertScalarTo<T>(num_lanes));
446 // First offset where not all vectors are guaranteed valid.
447 const V kIota = Iota(d, static_cast<T>(kMinLanes));
448 const V k1 = Set(d, static_cast<T>(kLanesPerRow));
449 const V k2 = Add(k1, k1);
450 const V k4 = Add(k2, k2);
451 const V k8 = Add(k4, k4);
452
453 using M = Mask<decltype(d)>;
454 const M m8 = Gt(vnum_lanes, kIota);
455 const M m9 = Gt(vnum_lanes, Add(kIota, k1));
456 const M ma = Gt(vnum_lanes, Add(kIota, k2));
457 const M mb = Gt(vnum_lanes, Add(kIota, Sub(k4, k1)));
458 const M mc = Gt(vnum_lanes, Add(kIota, k4));
459 const M md = Gt(vnum_lanes, Add(kIota, Add(k4, k1)));
460 const M me = Gt(vnum_lanes, Add(kIota, Add(k4, k2)));
461 const M mf = Gt(vnum_lanes, Add(kIota, Sub(k8, k1)));
462
463 const V kPadding = st.LastValue(d); // Not copied to keys.
464 v8 = MaskedLoadOr(kPadding, m8, d, keys + 0x8 * kLanesPerRow);
465 v9 = MaskedLoadOr(kPadding, m9, d, keys + 0x9 * kLanesPerRow);
466 va = MaskedLoadOr(kPadding, ma, d, keys + 0xa * kLanesPerRow);
467 vb = MaskedLoadOr(kPadding, mb, d, keys + 0xb * kLanesPerRow);
468 vc = MaskedLoadOr(kPadding, mc, d, keys + 0xc * kLanesPerRow);
469 vd = MaskedLoadOr(kPadding, md, d, keys + 0xd * kLanesPerRow);
470 ve = MaskedLoadOr(kPadding, me, d, keys + 0xe * kLanesPerRow);
471 vf = MaskedLoadOr(kPadding, mf, d, keys + 0xf * kLanesPerRow);
472#endif // !HWY_MEM_OPS_MIGHT_FAULT
473
474 Sort16(d, st, v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, va, vb, vc, vd, ve, vf);
475
476 // Merge16x4 is a no-op if kKeysPerRow < 4 etc.
477 Merge16x2<kKeysPerRow>(d, st, v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, va, vb,
478 vc, vd, ve, vf);
479 Merge16x4<kKeysPerRow>(d, st, v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, va, vb,
480 vc, vd, ve, vf);
481 Merge16x8<kKeysPerRow>(d, st, v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, va, vb,
482 vc, vd, ve, vf);
483#if !HWY_COMPILER_MSVC && !HWY_IS_DEBUG_BUILD
484 Merge16x16<kKeysPerRow>(d, st, v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, va, vb,
485 vc, vd, ve, vf);
486#endif
487
488 StoreU(v0, d, keys + 0x0 * kLanesPerRow);
489 StoreU(v1, d, keys + 0x1 * kLanesPerRow);
490 StoreU(v2, d, keys + 0x2 * kLanesPerRow);
491 StoreU(v3, d, keys + 0x3 * kLanesPerRow);
492 StoreU(v4, d, keys + 0x4 * kLanesPerRow);
493 StoreU(v5, d, keys + 0x5 * kLanesPerRow);
494 StoreU(v6, d, keys + 0x6 * kLanesPerRow);
495 StoreU(v7, d, keys + 0x7 * kLanesPerRow);
496
497#if HWY_MEM_OPS_MIGHT_FAULT
498 // Store remaining vectors into buf and safely copy them into keys.
499 StoreU(v8, d, buf + 0x8 * kLanesPerRow);
500 StoreU(v9, d, buf + 0x9 * kLanesPerRow);
501 StoreU(va, d, buf + 0xa * kLanesPerRow);
502 StoreU(vb, d, buf + 0xb * kLanesPerRow);
503 StoreU(vc, d, buf + 0xc * kLanesPerRow);
504 StoreU(vd, d, buf + 0xd * kLanesPerRow);
505 StoreU(ve, d, buf + 0xe * kLanesPerRow);
506 StoreU(vf, d, buf + 0xf * kLanesPerRow);
507
508 const ScalableTag<T> dmax;
509 const size_t Nmax = Lanes(dmax);
510
511 // The first half of vectors have already been stored unconditionally into
512 // `keys`, so we do not copy them.
513 size_t i = kMinLanes;
514 HWY_UNROLL(1)
515 for (; i + Nmax <= num_lanes; i += Nmax) {
516 StoreU(LoadU(dmax, buf + i), dmax, keys + i);
517 }
518
519 // Last iteration: copy partial vector
520 const size_t remaining = num_lanes - i;
521 HWY_ASSUME(remaining < 256); // helps FirstN
522 SafeCopyN(remaining, dmax, buf + i, keys + i);
523#endif // HWY_MEM_OPS_MIGHT_FAULT
524#if !HWY_MEM_OPS_MIGHT_FAULT || HWY_IDE
525 BlendedStore(v8, m8, d, keys + 0x8 * kLanesPerRow);
526 BlendedStore(v9, m9, d, keys + 0x9 * kLanesPerRow);
527 BlendedStore(va, ma, d, keys + 0xa * kLanesPerRow);
528 BlendedStore(vb, mb, d, keys + 0xb * kLanesPerRow);
529 BlendedStore(vc, mc, d, keys + 0xc * kLanesPerRow);
530 BlendedStore(vd, md, d, keys + 0xd * kLanesPerRow);
531 BlendedStore(ve, me, d, keys + 0xe * kLanesPerRow);
532 BlendedStore(vf, mf, d, keys + 0xf * kLanesPerRow);
533#endif // !HWY_MEM_OPS_MIGHT_FAULT
534}
535
536// Sorts `keys` within the range [0, num_lanes) via sorting network.
537// Reshapes into a matrix, sorts columns independently, and then merges
538// into a sorted 1D array without transposing.
539//
540// `TraitsKV` is SharedTraits<Traits*<Order*>>. This abstraction layer bridges
541// differences in sort order and single-lane vs 128-bit keys. For key-value
542// types, items with the same key are not equivalent. Our sorting network
543// does not preserve order, thus we prevent mixing padding into the items by
544// comparing all the item bits, including the value (see *ForSortingNetwork).
545//
546// See M. Blacher's thesis: https://github.com/mark-blacher/masterthesis
547template <class D, class TraitsKV, typename T>
548HWY_NOINLINE void BaseCase(D d, TraitsKV, T* HWY_RESTRICT keys,
549 size_t num_lanes, T* buf) {
550 using Traits = typename TraitsKV::SharedTraitsForSortingNetwork;
551 Traits st;
552 constexpr size_t kLPK = st.LanesPerKey();
553 HWY_DASSERT(num_lanes <= Constants::BaseCaseNumLanes<kLPK>(Lanes(d)));
554 const size_t num_keys = num_lanes / kLPK;
555
556 // Can be zero when called through HandleSpecialCases, or 1 (in which case
557 // the array is already sorted). This also ensures num_keys - 1 != 0 below.
558 if (HWY_UNLIKELY(num_keys <= 1)) return;
559
560 const size_t ceil_log2 =
561 32 - Num0BitsAboveMS1Bit_Nonzero32(static_cast<uint32_t>(num_keys - 1));
562
563 // Checking kMaxKeysPerVector avoids generating unreachable codepaths.
564 constexpr size_t kMaxKeysPerVector = MaxLanes(d) / kLPK;
565
566 using FuncPtr = decltype(&Sort2To2<Traits, T>);
567 const FuncPtr funcs[9] = {
568 /* <= 1 */ nullptr, // We ensured num_keys > 1.
569 /* <= 2 */ &Sort2To2<Traits, T>,
570 /* <= 4 */ &Sort3To4<Traits, T>,
571 /* <= 8 */ &Sort8Rows<1, Traits, T>, // 1 key per row
572 /* <= 16 */ kMaxKeysPerVector >= 2 ? &Sort8Rows<2, Traits, T> : nullptr,
573 /* <= 32 */ kMaxKeysPerVector >= 4 ? &Sort8Rows<4, Traits, T> : nullptr,
574 /* <= 64 */ kMaxKeysPerVector >= 4 ? &Sort16Rows<4, Traits, T> : nullptr,
575 /* <= 128 */ kMaxKeysPerVector >= 8 ? &Sort16Rows<8, Traits, T> : nullptr,
576#if !HWY_COMPILER_MSVC && !HWY_IS_DEBUG_BUILD
577 /* <= 256 */ kMaxKeysPerVector >= 16 ? &Sort16Rows<16, Traits, T> : nullptr,
578#endif
579 };
580 funcs[ceil_log2](st, keys, num_lanes, buf);
581}
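// Editorial note (worked example, not in the original file): for num_keys = 5,
// num_keys - 1 = 4 has its most significant bit at position 2, so ceil_log2 =
// 32 - 29 = 3 and funcs[3] = Sort8Rows<1> handles 5..8 keys; num_keys = 2
// gives ceil_log2 = 1 (Sort2To2), and num_keys = 100 gives ceil_log2 = 7
// (Sort16Rows<8>, up to 128 keys), matching the table's power-of-two buckets.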
582
583// ------------------------------ Partition
584
585// Partitions O(1) of the *rightmost* keys, at least `N`, until a multiple of
586// kUnroll*N remains, or all keys if there are too few for that.
587//
588// Returns how many remain to partition at the *start* of `keys`, sets `bufL` to
589// the number of keys for the left partition written to `buf`, and `writeR` to
590// the start of the finished right partition at the end of `keys`.
591template <class D, class Traits, class T>
592HWY_INLINE size_t PartitionRightmost(D d, Traits st, T* const keys,
593 const size_t num, const Vec<D> pivot,
594 size_t& bufL, size_t& writeR,
595 T* HWY_RESTRICT buf) {
596 const size_t N = Lanes(d);
597 HWY_DASSERT(num > 2 * N); // BaseCase handles smaller arrays
598
599 constexpr size_t kUnroll = Constants::kPartitionUnroll;
600 size_t num_here; // how many to process here
601 size_t num_main; // how many for main Partition loop (return value)
602 {
603 // The main Partition loop increments by kUnroll * N, so at least handle
604 // the remainders here.
605 const size_t remainder = num & (kUnroll * N - 1);
606 // Ensure we handle at least one vector to prevent overruns (see below), but
607 // still leave a multiple of kUnroll * N.
608 const size_t min = remainder + (remainder < N ? kUnroll * N : 0);
609 // Do not exceed the input size.
610 num_here = HWY_MIN(min, num);
611 num_main = num - num_here;
612 // Before the main Partition loop we load two blocks; if not enough left for
613 // that, handle everything here.
614 if (num_main < 2 * kUnroll * N) {
615 num_here = num;
616 num_main = 0;
617 }
618 }
619
620 // Note that `StoreLeftRight` uses `CompressBlendedStore`, which may load and
621 // store a whole vector starting at `writeR`, and thus overrun `keys`. To
622 // prevent this, we partition at least `N` of the rightmost `keys` so that
623 // `StoreLeftRight` will be able to safely blend into them.
624 HWY_DASSERT(num_here >= N);
625
626 // We cannot use `CompressBlendedStore` for the same reason, so we instead
627 // write the right-of-partition keys into a buffer in ascending order.
628 // `min` may be up to (kUnroll + 1) * N, hence `num_here` could be as much as
629 // (3 * kUnroll + 1) * N, and they might all fall on one side of the pivot.
630 const size_t max_buf = (3 * kUnroll + 1) * N;
631 HWY_DASSERT(num_here <= max_buf);
632
633 const T* pReadR = keys + num; // pre-decremented by N
634
635 bufL = 0;
636 size_t bufR = max_buf; // starting position, not the actual count.
637
638 size_t i = 0;
639 // For whole vectors, we can LoadU.
640 for (; i <= num_here - N; i += N) {
641 pReadR -= N;
642 HWY_DASSERT(pReadR >= keys);
643 const Vec<D> v = LoadU(d, pReadR);
644
645 const Mask<D> comp = st.Compare(d, pivot, v);
646 const size_t numL = CompressStore(v, Not(comp), d, buf + bufL);
647 bufL += numL;
648 (void)CompressStore(v, comp, d, buf + bufR);
649 bufR += (N - numL);
650 }
651
652 // Last iteration: avoid reading past the end.
653 const size_t remaining = num_here - i;
654 if (HWY_LIKELY(remaining != 0)) {
655 const Mask<D> mask = FirstN(d, remaining);
656 pReadR -= remaining;
657 HWY_DASSERT(pReadR >= keys);
658 const Vec<D> v = LoadN(d, pReadR, remaining);
659
660 const Mask<D> comp = st.Compare(d, pivot, v);
661 const size_t numL = CompressStore(v, AndNot(comp, mask), d, buf + bufL);
662 bufL += numL;
663 (void)CompressStore(v, comp, d, buf + bufR);
664 bufR += (remaining - numL);
665 }
666
667 const size_t numWrittenR = bufR - max_buf;
668 // MSan seems not to understand CompressStore.
669 detail::MaybeUnpoison(buf, bufL);
670 detail::MaybeUnpoison(buf + max_buf, numWrittenR);
671
672 // Overwrite already-read end of keys with bufR.
673 writeR = num - numWrittenR;
674 hwy::CopyBytes(buf + max_buf, keys + writeR, numWrittenR * sizeof(T));
675 // Ensure we finished reading/writing all we wanted
676 HWY_DASSERT(pReadR == keys + num_main);
677 HWY_DASSERT(bufL + numWrittenR == num_here);
678 return num_main;
679}
680
681// Note: we could track the OrXor of v and pivot to see if the entire left
682// partition is equal, but that happens rarely and thus is a net loss.
683template <class D, class Traits, typename T>
684HWY_INLINE void StoreLeftRight(D d, Traits st, const Vec<D> v,
685 const Vec<D> pivot, T* HWY_RESTRICT keys,
686 size_t& writeL, size_t& remaining) {
687 const size_t N = Lanes(d);
688
689 const Mask<D> comp = st.Compare(d, pivot, v);
690
691 // Otherwise StoreU/CompressStore overwrites right keys.
692 HWY_DASSERT(remaining >= 2 * N);
693
694 remaining -= N;
695 if (CompressIsPartition<T>::value ||
696 (HWY_MAX_BYTES == 16 && st.Is128())) {
697 // Non-native Compress (e.g. AVX2): we are able to partition a vector using
698 // a single Compress+two StoreU instead of two Compress[Blended]Store. The
699 // latter are more expensive. Because we store entire vectors, the contents
700 // between the updated writeL and writeR are ignored and will be overwritten
701 // by subsequent calls. This works because writeL and writeR are at least
702 // two vectors apart.
703 const Vec<D> lr = st.CompressKeys(v, comp);
704 const size_t num_left = N - CountTrue(d, comp);
705 StoreU(lr, d, keys + writeL);
706 // Now write the right-side elements (if any), such that the previous writeR
707 // is one past the end of the newly written right elements, then advance.
708 StoreU(lr, d, keys + remaining + writeL);
709 writeL += num_left;
710 } else {
711 // Native Compress[Store] (e.g. AVX3), which only keep the left or right
712 // side, not both, hence we require two calls.
713 const size_t num_left = CompressStore(v, Not(comp), d, keys + writeL);
714 writeL += num_left;
715
716 (void)CompressBlendedStore(v, comp, d, keys + remaining + writeL);
717 }
718}
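// Editorial sketch (not part of the original sources): a scalar analogue of
// one StoreLeftRight step for uint32_t keys in ascending order, showing only
// the index bookkeeping; the real code performs the lane split with a single
// vector Compress (or CompressStore/CompressBlendedStore) instead of loops.
// The identity writeR == writeL + remaining is explained in Partition below.
#if 0  // illustration only, never compiled
#include <stddef.h>
#include <stdint.h>

static void ScalarStoreLeftRight(const uint32_t* v, size_t n, uint32_t pivot,
                                 uint32_t* keys, size_t* writeL,
                                 size_t* remaining) {
  *remaining -= n;  // n lanes leave the unpartitioned middle
  // Keys <= pivot go to the left region, starting at writeL.
  size_t num_left = 0;
  for (size_t i = 0; i < n; ++i) {
    if (v[i] <= pivot) {
      keys[*writeL + num_left] = v[i];
      ++num_left;
    }
  }
  *writeL += num_left;
  // Keys > pivot go to the right region; they end exactly at the old writeR.
  size_t num_right = 0;
  for (size_t i = 0; i < n; ++i) {
    if (pivot < v[i]) {
      keys[*writeL + *remaining + num_right] = v[i];
      ++num_right;
    }
  }
}
#endif  // 0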
719
720template <class D, class Traits, typename T>
721HWY_INLINE void StoreLeftRight4(D d, Traits st, const Vec<D> v0,
722 const Vec<D> v1, const Vec<D> v2,
723 const Vec<D> v3, const Vec<D> pivot,
724 T* HWY_RESTRICT keys, size_t& writeL,
725 size_t& remaining) {
726 StoreLeftRight(d, st, v0, pivot, keys, writeL, remaining);
727 StoreLeftRight(d, st, v1, pivot, keys, writeL, remaining);
728 StoreLeftRight(d, st, v2, pivot, keys, writeL, remaining);
729 StoreLeftRight(d, st, v3, pivot, keys, writeL, remaining);
730}
731
732// For the last two vectors, we cannot use StoreLeftRight because it might
733// overwrite prior right-side keys. Instead write R and append L into `buf`.
734template <class D, class Traits, typename T>
735HWY_INLINE void StoreRightAndBuf(D d, Traits st, const Vec<D> v,
736 const Vec<D> pivot, T* HWY_RESTRICT keys,
737 size_t& writeR, T* HWY_RESTRICT buf,
738 size_t& bufL) {
739 const size_t N = Lanes(d);
740 const Mask<D> comp = st.Compare(d, pivot, v);
741 const size_t numL = CompressStore(v, Not(comp), d, buf + bufL);
742 bufL += numL;
743 writeR -= (N - numL);
744 (void)CompressBlendedStore(v, comp, d, keys + writeR);
745}
746
747// Moves "<= pivot" keys to the front, and others to the back. `pivot` is
748// broadcast. Returns the index of the first key in the right partition.
749//
750// Time-critical, but aligned loads do not seem to be worthwhile because we
751// are not bottlenecked by load ports.
752template <class D, class Traits, typename T>
753HWY_INLINE size_t Partition(D d, Traits st, T* const keys, const size_t num,
754 const Vec<D> pivot, T* HWY_RESTRICT buf) {
755 using V = decltype(Zero(d));
756 const size_t N = Lanes(d);
757
758 size_t bufL, writeR;
759 const size_t num_main =
760 PartitionRightmost(d, st, keys, num, pivot, bufL, writeR, buf);
761 HWY_DASSERT(num_main <= num && writeR <= num);
763 HWY_DASSERT(num_main + bufL == writeR);
764
765 if (VQSORT_PRINT >= 3) {
766 fprintf(stderr, " num_main %zu bufL %zu writeR %zu\n", num_main, bufL,
767 writeR);
768 }
769
770 constexpr size_t kUnroll = Constants::kPartitionUnroll;
771
772 // Partition splits the array into three sections, left to right: elements
773 // smaller than or equal to the pivot, unpartitioned elements, and elements
774 // larger than the pivot. To write elements unconditionally in the loop body
775 // without overwriting existing data, we maintain two regions of the array
776 // whose elements have already been copied elsewhere (e.g. into vector
777 // registers). We call these bufferL and bufferR, for left and right.
778 //
779 // These regions are tracked by the indices (writeL, writeR, left, right) as
780 // presented in the diagram below.
781 //
782 // writeL writeR
783 // \/ \/
784 // | <= pivot | bufferL | unpartitioned | bufferR | > pivot |
785 // \/ \/ \/
786 // readL readR num
787 //
788 // In the main loop body below we choose a side, load some elements out of the
789 // vector and move either `readL` or `readR`. Next we call into StoreLeftRight
790 // to partition the data, and the partitioned elements will be written either
791 // to writeR or writeL and the corresponding index will be moved accordingly.
792 //
793 // Note that writeR is not explicitly tracked as an optimization for platforms
794 // with conditional operations. Instead we track writeL and the number of
795 // not yet written elements (`remaining`). From the diagram above we can see
796 // that:
797 // writeR - writeL = remaining => writeR = remaining + writeL
798 //
799 // Tracking `remaining` is advantageous because each iteration reduces the
800 // number of unpartitioned elements by a fixed amount, so we can compute
801 // `remaining` without data dependencies.
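  // Editorial example (not in the original comments): with N = 8, kUnroll = 4
  // and writeR = 96 after PartitionRightmost, we start with writeL = 0 and
  // remaining = 96. Each StoreLeftRight consumes one vector: remaining drops
  // by 8 and writeL grows by that vector's left-side count, so the implicit
  // right cursor writeR = writeL + remaining never moves right.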
802 size_t writeL = 0;
803 size_t remaining = writeR - writeL;
804
805 const T* readL = keys;
806 const T* readR = keys + num_main;
807 // Cannot load if there were fewer than 2 * kUnroll * N.
808 if (HWY_LIKELY(num_main != 0)) {
809 HWY_DASSERT(num_main >= 2 * kUnroll * N);
810 HWY_DASSERT((num_main & (kUnroll * N - 1)) == 0);
811
812 // Make space for writing in-place by reading from readL/readR.
813 const V vL0 = LoadU(d, readL + 0 * N);
814 const V vL1 = LoadU(d, readL + 1 * N);
815 const V vL2 = LoadU(d, readL + 2 * N);
816 const V vL3 = LoadU(d, readL + 3 * N);
817 readL += kUnroll * N;
818 readR -= kUnroll * N;
819 const V vR0 = LoadU(d, readR + 0 * N);
820 const V vR1 = LoadU(d, readR + 1 * N);
821 const V vR2 = LoadU(d, readR + 2 * N);
822 const V vR3 = LoadU(d, readR + 3 * N);
823
824 // readL/readR changed above, so check again before the loop.
825 while (readL != readR) {
826 V v0, v1, v2, v3;
827
828 // Data-dependent but branching is faster than forcing branch-free.
829 const size_t capacityL =
830 static_cast<size_t>((readL - keys) - static_cast<ptrdiff_t>(writeL));
831 HWY_DASSERT(capacityL <= num_main); // >= 0
832 // Load data from the end of the vector with less data (front or back).
833 // The next paragraphs explain how this works.
834 //
835 // let block_size = (kUnroll * N)
836 // On the loop prelude we load block_size elements from the front of the
837 // vector and an additional block_size elements from the back. On each
838 // iteration k elements are written to the front of the vector and
839 // (block_size - k) to the back.
840 //
841 // This creates a loop invariant where the capacity on the front
842 // (capacityL) and on the back (capacityR) always add to 2 * block_size.
843 // In other words:
844 // capacityL + capacityR = 2 * block_size
845 // capacityR = 2 * block_size - capacityL
846 //
847 // This means that:
848 // capacityL > capacityR <=>
849 // capacityL > 2 * block_size - capacityL <=>
850 // 2 * capacityL > 2 * block_size <=>
851 // capacityL > block_size
852 if (capacityL > kUnroll * N) { // equivalent to capacityL > capacityR.
853 readR -= kUnroll * N;
854 v0 = LoadU(d, readR + 0 * N);
855 v1 = LoadU(d, readR + 1 * N);
856 v2 = LoadU(d, readR + 2 * N);
857 v3 = LoadU(d, readR + 3 * N);
858 hwy::Prefetch(readR - 3 * kUnroll * N);
859 } else {
860 v0 = LoadU(d, readL + 0 * N);
861 v1 = LoadU(d, readL + 1 * N);
862 v2 = LoadU(d, readL + 2 * N);
863 v3 = LoadU(d, readL + 3 * N);
864 readL += kUnroll * N;
865 hwy::Prefetch(readL + 3 * kUnroll * N);
866 }
867
868 StoreLeftRight4(d, st, v0, v1, v2, v3, pivot, keys, writeL, remaining);
869 }
870
871 // Now finish writing the saved vectors to the middle.
872 StoreLeftRight4(d, st, vL0, vL1, vL2, vL3, pivot, keys, writeL, remaining);
873
874 StoreLeftRight(d, st, vR0, pivot, keys, writeL, remaining);
875 StoreLeftRight(d, st, vR1, pivot, keys, writeL, remaining);
876
877 // Switch back to updating writeR for clarity. The middle is missing vR2/3
878 // and what is in the buffer.
879 HWY_DASSERT(remaining == bufL + 2 * N);
880 writeR = writeL + remaining;
881 // Switch to StoreRightAndBuf for the last two vectors because
882 // StoreLeftRight may overwrite prior keys.
883 StoreRightAndBuf(d, st, vR2, pivot, keys, writeR, buf, bufL);
884 StoreRightAndBuf(d, st, vR3, pivot, keys, writeR, buf, bufL);
885 HWY_DASSERT(writeR <= num); // >= 0
887 }
888
889 // We have partitioned [0, num) into [0, writeL) and [writeR, num).
890 // Now insert left keys from `buf` to empty space starting at writeL.
891 HWY_DASSERT(writeL + bufL == writeR);
892 CopyBytes(buf, keys + writeL, bufL * sizeof(T));
893
894 return writeL + bufL;
895}
896
897// Returns true and partitions if [keys, keys + num) contains only {valueL,
898// valueR}. Otherwise, sets third to the first differing value; keys may have
899// been reordered and a regular Partition is still necessary.
900// Called from two locations, hence NOINLINE.
901template <class D, class Traits, typename T>
902HWY_NOINLINE bool MaybePartitionTwoValue(D d, Traits st, T* HWY_RESTRICT keys,
903 size_t num, const Vec<D> valueL,
904 const Vec<D> valueR, Vec<D>& third,
905 T* HWY_RESTRICT buf) {
906 const size_t N = Lanes(d);
907 // No guarantee that num >= N because this is called for subarrays!
908
909 size_t i = 0;
910 size_t writeL = 0;
911
912 // As long as all lanes are equal to L or R, we can overwrite with valueL.
913 // This is faster than first counting, then backtracking to fill L and R.
914 if (num >= N) {
915 for (; i <= num - N; i += N) {
916 const Vec<D> v = LoadU(d, keys + i);
917 // It is not clear how to apply OrXor here - that can check if *both*
918 // comparisons are true, but here we want *either*. Comparing the unsigned
919 // min of differences to zero works, but is expensive for u64 prior to
920 // AVX3.
921 const Mask<D> eqL = st.EqualKeys(d, v, valueL);
922 const Mask<D> eqR = st.EqualKeys(d, v, valueR);
923 // At least one other value present; will require a regular partition.
924 // On AVX-512, Or + AllTrue are folded into a single kortest if we are
925 // careful with the FindKnownFirstTrue argument, see below.
926 if (HWY_UNLIKELY(!AllTrue(d, Or(eqL, eqR)))) {
927 // If we repeat Or(eqL, eqR) here, the compiler will hoist it into the
928 // loop, which is a pessimization because this if-true branch is cold.
929 // We can defeat this via Not(Xor), which is equivalent because eqL and
930 // eqR cannot be true at the same time. Can we elide the additional Not?
931 // FindFirstFalse instructions are generally unavailable, but we can
932 // fuse Not and Xor/Or into one ExclusiveNeither.
933 const size_t lane = FindKnownFirstTrue(d, ExclusiveNeither(eqL, eqR));
934 third = st.SetKey(d, keys + i + lane);
935 if (VQSORT_PRINT >= 2) {
936 fprintf(stderr, "found 3rd value at vec %zu; writeL %zu\n", i,
937 writeL);
938 }
939 // 'Undo' what we did by filling the remainder of what we read with R.
940 if (i >= N) {
941 for (; writeL <= i - N; writeL += N) {
942 StoreU(valueR, d, keys + writeL);
943 }
944 }
945 BlendedStore(valueR, FirstN(d, i - writeL), d, keys + writeL);
946 return false;
947 }
948 StoreU(valueL, d, keys + writeL);
949 writeL += CountTrue(d, eqL);
950 }
951 }
952
953 // Final vector, masked comparison (no effect if i == num)
954 const size_t remaining = num - i;
955 SafeCopyN(remaining, d, keys + i, buf);
956 const Vec<D> v = Load(d, buf);
957 const Mask<D> valid = FirstN(d, remaining);
958 const Mask<D> eqL = And(st.EqualKeys(d, v, valueL), valid);
959 const Mask<D> eqR = st.EqualKeys(d, v, valueR);
960 // Invalid lanes are considered equal.
961 const Mask<D> eq = Or(Or(eqL, eqR), Not(valid));
962 // At least one other value present; will require a regular partition.
963 if (HWY_UNLIKELY(!AllTrue(d, eq))) {
964 const size_t lane = FindKnownFirstTrue(d, Not(eq));
965 third = st.SetKey(d, keys + i + lane);
966 if (VQSORT_PRINT >= 2) {
967 fprintf(stderr, "found 3rd value at partial vec %zu; writeL %zu\n", i,
968 writeL);
969 }
970 // 'Undo' what we did by filling the remainder of what we read with R.
971 if (i >= N) {
972 for (; writeL <= i - N; writeL += N) {
973 StoreU(valueR, d, keys + writeL);
974 }
975 }
976 BlendedStore(valueR, FirstN(d, i - writeL), d, keys + writeL);
977 return false;
978 }
979 BlendedStore(valueL, valid, d, keys + writeL);
980 writeL += CountTrue(d, eqL);
981
982 // Fill right side
983 i = writeL;
984 if (num >= N) {
985 for (; i <= num - N; i += N) {
986 StoreU(valueR, d, keys + i);
987 }
988 }
989 BlendedStore(valueR, FirstN(d, num - i), d, keys + i);
990
991 if (VQSORT_PRINT >= 2) {
992 fprintf(stderr, "Successful MaybePartitionTwoValue\n");
993 }
994 return true;
995}
996
997// Same as above, except that the pivot equals valueR, so scan right to left.
998template <class D, class Traits, typename T>
999HWY_INLINE bool MaybePartitionTwoValueR(D d, Traits st, T* HWY_RESTRICT keys,
1000 size_t num, const Vec<D> valueL,
1001 const Vec<D> valueR, Vec<D>& third,
1002 T* HWY_RESTRICT buf) {
1003 const size_t N = Lanes(d);
1004
1005 HWY_DASSERT(num >= N);
1006 size_t pos = num - N; // current read/write position
1007 size_t countR = 0; // number of valueR found
1008
1009 // For whole vectors, in descending address order: as long as all lanes are
1010 // equal to L or R, overwrite with valueR. This is faster than counting, then
1011 // filling both L and R. Loop terminates after unsigned wraparound.
1012 for (; pos < num; pos -= N) {
1013 const Vec<D> v = LoadU(d, keys + pos);
1014 // It is not clear how to apply OrXor here - that can check if *both*
1015 // comparisons are true, but here we want *either*. Comparing the unsigned
1016 // min of differences to zero works, but is expensive for u64 prior to AVX3.
1017 const Mask<D> eqL = st.EqualKeys(d, v, valueL);
1018 const Mask<D> eqR = st.EqualKeys(d, v, valueR);
1019 // If there is a third value, stop and undo what we've done. On AVX-512,
1020 // Or + AllTrue are folded into a single kortest, but only if we are
1021 // careful with the FindKnownFirstTrue argument - see prior comment on that.
1022 if (HWY_UNLIKELY(!AllTrue(d, Or(eqL, eqR)))) {
1023 const size_t lane = FindKnownFirstTrue(d, ExclusiveNeither(eqL, eqR));
1024 third = st.SetKey(d, keys + pos + lane);
1025 if (VQSORT_PRINT >= 2) {
1026 fprintf(stderr, "found 3rd value at vec %zu; countR %zu\n", pos,
1027 countR);
1028 MaybePrintVector(d, "third", third, 0, st.LanesPerKey());
1029 }
1030 pos += N; // rewind: we haven't yet committed changes in this iteration.
1031 // We have filled [pos, num) with R, but only countR of them should have
1032 // been written. Rewrite [pos, num - countR) to L.
1033 HWY_DASSERT(countR <= num - pos);
1034 const size_t endL = num - countR;
1035 if (endL >= N) {
1036 for (; pos <= endL - N; pos += N) {
1037 StoreU(valueL, d, keys + pos);
1038 }
1039 }
1040 BlendedStore(valueL, FirstN(d, endL - pos), d, keys + pos);
1041 return false;
1042 }
1043 StoreU(valueR, d, keys + pos);
1044 countR += CountTrue(d, eqR);
1045 }
1046
1047 // Final partial (or empty) vector, masked comparison.
1048 const size_t remaining = pos + N;
1049 HWY_DASSERT(remaining <= N);
1050 const Vec<D> v = LoadU(d, keys); // Safe because num >= N.
1051 const Mask<D> valid = FirstN(d, remaining);
1052 const Mask<D> eqL = st.EqualKeys(d, v, valueL);
1053 const Mask<D> eqR = And(st.EqualKeys(d, v, valueR), valid);
1054 // Invalid lanes are considered equal.
1055 const Mask<D> eq = Or(Or(eqL, eqR), Not(valid));
1056 // At least one other value present; will require a regular partition.
1057 if (HWY_UNLIKELY(!AllTrue(d, eq))) {
1058 const size_t lane = FindKnownFirstTrue(d, Not(eq));
1059 third = st.SetKey(d, keys + lane);
1060 if (VQSORT_PRINT >= 2) {
1061 fprintf(stderr, "found 3rd value at partial vec %zu; writeR %zu\n", pos,
1062 countR);
1063 MaybePrintVector(d, "third", third, 0, st.LanesPerKey());
1064 }
1065 pos += N; // rewind: we haven't yet committed changes in this iteration.
1066 // We have filled [pos, num) with R, but only countR of them should have
1067 // been written. Rewrite [pos, num - countR) to L.
1068 HWY_DASSERT(countR <= num - pos);
1069 const size_t endL = num - countR;
1070 if (endL >= N) {
1071 for (; pos <= endL - N; pos += N) {
1072 StoreU(valueL, d, keys + pos);
1073 }
1074 }
1075 BlendedStore(valueL, FirstN(d, endL - pos), d, keys + pos);
1076 return false;
1077 }
1078 const size_t lastR = CountTrue(d, eqR);
1079 countR += lastR;
1080
1081 // First finish writing valueR - [0, N) lanes were not yet written.
1082 StoreU(valueR, d, keys); // Safe because num >= N.
1083
1084 // Fill left side (ascending order for clarity)
1085 const size_t endL = num - countR;
1086 size_t i = 0;
1087 if (endL >= N) {
1088 for (; i <= endL - N; i += N) {
1089 StoreU(valueL, d, keys + i);
1090 }
1091 }
1092 Store(valueL, d, buf);
1093 SafeCopyN(endL - i, d, buf, keys + i); // avoids ASan overrun
1094
1095 if (VQSORT_PRINT >= 2) {
1096 fprintf(stderr,
1097 "MaybePartitionTwoValueR countR %zu pos %zu i %zu endL %zu\n",
1098 countR, pos, i, endL);
1099 }
1100
1101 return true;
1102}
1103
1104// `idx_second` is `first_mismatch` from `AllEqual` and thus the index of the
1105// second key. This is the first path into `MaybePartitionTwoValue`, called
1106// when all samples are equal. Returns false if there is at least a third
1107// value and sets `third`. Otherwise, partitions the array and returns true.
1108template <class D, class Traits, typename T>
1109HWY_INLINE bool PartitionIfTwoKeys(D d, Traits st, const Vec<D> pivot,
1110 T* HWY_RESTRICT keys, size_t num,
1111 const size_t idx_second, const Vec<D> second,
1112 Vec<D>& third, T* HWY_RESTRICT buf) {
1113 // True if second comes before pivot.
1114 const bool is_pivotR = AllFalse(d, st.Compare(d, pivot, second));
1115 if (VQSORT_PRINT >= 1) {
1116 fprintf(stderr, "Samples all equal, diff at %zu, isPivotR %d\n", idx_second,
1117 is_pivotR);
1118 }
1119 HWY_DASSERT(AllFalse(d, st.EqualKeys(d, second, pivot)));
1120
1121 // If pivot is R, we scan backwards over the entire array. Otherwise,
1122 // we already scanned up to idx_second and can leave those in place.
1123 return is_pivotR ? MaybePartitionTwoValueR(d, st, keys, num, second, pivot,
1124 third, buf)
1125 : MaybePartitionTwoValue(d, st, keys + idx_second,
1126 num - idx_second, pivot, second,
1127 third, buf);
1128}
1129
1130// Second path into `MaybePartitionTwoValue`, called when not all samples are
1131// equal. `samples` is sorted.
1132template <class D, class Traits, typename T>
1133HWY_INLINE bool PartitionIfTwoSamples(D d, Traits st, T* HWY_RESTRICT keys,
1134 size_t num, T* HWY_RESTRICT samples) {
1135 constexpr size_t kSampleLanes = Constants::SampleLanes<T>();
1136 constexpr size_t N1 = st.LanesPerKey();
1137 const Vec<D> valueL = st.SetKey(d, samples);
1138 const Vec<D> valueR = st.SetKey(d, samples + kSampleLanes - N1);
1139 HWY_DASSERT(AllTrue(d, st.Compare(d, valueL, valueR)));
1140 HWY_DASSERT(AllFalse(d, st.EqualKeys(d, valueL, valueR)));
1141 const Vec<D> prev = st.PrevValue(d, valueR);
1142 // If the sample has more than two values, then the keys have at least that
1143 // many, and thus this special case is inapplicable.
1144 if (HWY_UNLIKELY(!AllTrue(d, st.EqualKeys(d, valueL, prev)))) {
1145 return false;
1146 }
1147
1148 // Must not overwrite samples because if this returns false, caller wants to
1149 // read the original samples again.
1150 T* HWY_RESTRICT buf = samples + kSampleLanes;
1151 Vec<D> third; // unused
1152 return MaybePartitionTwoValue(d, st, keys, num, valueL, valueR, third, buf);
1153}
1154
1155// ------------------------------ Pivot sampling
1156
1157template <class Traits, class V>
1158HWY_INLINE V MedianOf3(Traits st, V v0, V v1, V v2) {
1159 const DFromV<V> d;
1160 // Slightly faster for 128-bit, apparently because not serially dependent.
1161 if (st.Is128()) {
1162 // Median = XOR-sum 'minus' the first and last. Calling First twice is
1163 // slightly faster than Compare + 2 IfThenElse or even IfThenElse + XOR.
1164 const V sum = Xor(Xor(v0, v1), v2);
1165 const V first = st.First(d, st.First(d, v0, v1), v2);
1166 const V last = st.Last(d, st.Last(d, v0, v1), v2);
1167 return Xor(Xor(sum, first), last);
1168 }
1169 st.Sort2(d, v0, v2);
1170 v1 = st.Last(d, v0, v1);
1171 v1 = st.First(d, v1, v2);
1172 return v1;
1173}
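// Editorial note (worked example, not in the original file): the XOR form
// above relies on x ^ x == 0, so (v0 ^ v1 ^ v2) ^ first ^ last leaves the
// middle element per key. E.g. for lane values {3, 9, 5}: 3 ^ 9 ^ 5 = 15,
// then 15 ^ 3 (first) = 12 and 12 ^ 9 (last) = 5, the median.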
1174
1175// Based on https://github.com/numpy/numpy/issues/16313#issuecomment-641897028
1176HWY_INLINE uint64_t RandomBits(uint64_t* HWY_RESTRICT state) {
1177 const uint64_t a = state[0];
1178 const uint64_t b = state[1];
1179 const uint64_t w = state[2] + 1;
1180 const uint64_t next = a ^ w;
1181 state[0] = (b + (b << 3)) ^ (b >> 11);
1182 const uint64_t rot = (b << 24) | (b >> 40);
1183 state[1] = rot + next;
1184 state[2] = w;
1185 return next;
1186}
1187
1188// Returns slightly biased random index of a chunk in [0, num_chunks).
1189// See https://www.pcg-random.org/posts/bounded-rands.html.
1190HWY_INLINE size_t RandomChunkIndex(const uint32_t num_chunks, uint32_t bits) {
1191 const uint64_t chunk_index = (static_cast<uint64_t>(bits) * num_chunks) >> 32;
1192 HWY_DASSERT(chunk_index < num_chunks);
1193 return static_cast<size_t>(chunk_index);
1194}
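// Editorial note (worked example, not in the original file): the multiply-
// shift above maps a uniform 32-bit value into [0, num_chunks) without a
// modulo, e.g. bits = 0x80000000 with num_chunks = 10 gives
// (0x80000000 * 10) >> 32 = 5. The slight bias mentioned above arises
// because 2^32 is not divisible by num_chunks.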
1195
1196// Writes samples from `keys[0, num)` into `buf`.
1197template <class D, class Traits, typename T>
1198HWY_INLINE void DrawSamples(D d, Traits st, T* HWY_RESTRICT keys, size_t num,
1199 T* HWY_RESTRICT buf, uint64_t* HWY_RESTRICT state) {
1200 using V = decltype(Zero(d));
1201 const size_t N = Lanes(d);
1202
1203 // Power of two
1204 constexpr size_t kLanesPerChunk = Constants::LanesPerChunk(sizeof(T));
1205
1206 // Align start of keys to chunks. We have at least 2 chunks (x 64 bytes)
1207 // because the base case handles anything up to 8 vectors (x 16 bytes).
1208 HWY_DASSERT(num >= Constants::SampleLanes<T>());
1209 const size_t misalign =
1210 (reinterpret_cast<uintptr_t>(keys) / sizeof(T)) & (kLanesPerChunk - 1);
1211 if (misalign != 0) {
1212 const size_t consume = kLanesPerChunk - misalign;
1213 keys += consume;
1214 num -= consume;
1215 }
1216
1217 // Generate enough random bits for 6 uint32
1218 uint32_t bits[6];
1219 for (size_t i = 0; i < 6; i += 2) {
1220 const uint64_t bits64 = RandomBits(state);
1221 CopyBytes<8>(&bits64, bits + i);
1222 }
1223
1224 const size_t num_chunks64 = num / kLanesPerChunk;
1225 // Clamp to uint32 for RandomChunkIndex
1226 const uint32_t num_chunks =
1227 static_cast<uint32_t>(HWY_MIN(num_chunks64, 0xFFFFFFFFull));
1228
1229 const size_t offset0 = RandomChunkIndex(num_chunks, bits[0]) * kLanesPerChunk;
1230 const size_t offset1 = RandomChunkIndex(num_chunks, bits[1]) * kLanesPerChunk;
1231 const size_t offset2 = RandomChunkIndex(num_chunks, bits[2]) * kLanesPerChunk;
1232 const size_t offset3 = RandomChunkIndex(num_chunks, bits[3]) * kLanesPerChunk;
1233 const size_t offset4 = RandomChunkIndex(num_chunks, bits[4]) * kLanesPerChunk;
1234 const size_t offset5 = RandomChunkIndex(num_chunks, bits[5]) * kLanesPerChunk;
1235 for (size_t i = 0; i < kLanesPerChunk; i += N) {
1236 const V v0 = Load(d, keys + offset0 + i);
1237 const V v1 = Load(d, keys + offset1 + i);
1238 const V v2 = Load(d, keys + offset2 + i);
1239 const V medians0 = MedianOf3(st, v0, v1, v2);
1240 Store(medians0, d, buf + i);
1241
1242 const V v3 = Load(d, keys + offset3 + i);
1243 const V v4 = Load(d, keys + offset4 + i);
1244 const V v5 = Load(d, keys + offset5 + i);
1245 const V medians1 = MedianOf3(st, v3, v4, v5);
1246 Store(medians1, d, buf + i + kLanesPerChunk);
1247 }
1248}
1249
1250template <class V>
1251V OrXor(const V o, const V x1, const V x2) {
1252 return Or(o, Xor(x1, x2)); // TERNLOG on AVX3
1253}
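// Editorial note (not in the original file): OrXor accumulates a "sticky"
// per-lane difference. Since x ^ x == 0, Xor(x1, x2) is nonzero exactly in
// lanes where x1 and x2 differ, and OR-ing those results across iterations
// (as UnsortedSampleEqual below and AllEqual do) leaves a nonzero lane iff
// any visited vector differed from the reference, for one cheap op per
// vector instead of a full comparison per vector.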
1254
1255// For detecting inputs where (almost) all keys are equal.
1256template <class D, class Traits>
1257HWY_INLINE bool UnsortedSampleEqual(D d, Traits st,
1258 const TFromD<D>* HWY_RESTRICT samples) {
1259 constexpr size_t kSampleLanes = Constants::SampleLanes<TFromD<D>>();
1260 const size_t N = Lanes(d);
1261 // Both are powers of two, so there will be no remainders.
1262 HWY_DASSERT(N < kSampleLanes);
1263 using V = Vec<D>;
1264
1265 const V first = st.SetKey(d, samples);
1266
1267 if (!hwy::IsFloat<TFromD<D>>()) {
1268 // OR of XOR-difference may be faster than comparison.
1269 V diff = Zero(d);
1270 for (size_t i = 0; i < kSampleLanes; i += N) {
1271 const V v = Load(d, samples + i);
1272 diff = OrXor(diff, first, v);
1273 }
1274 return st.NoKeyDifference(d, diff);
1275 } else {
1276 // Disable the OrXor optimization for floats because OrXor will not treat
1277 // subnormals the same as actual comparisons, leading to logic errors for
1278 // 2-value cases.
1279 for (size_t i = 0; i < kSampleLanes; i += N) {
1280 const V v = Load(d, samples + i);
1281 if (!AllTrue(d, st.EqualKeys(d, v, first))) {
1282 return false;
1283 }
1284 }
1285 return true;
1286 }
1287}
1288
1289template <class D, class Traits, typename T>
1290HWY_INLINE void SortSamples(D d, Traits st, T* HWY_RESTRICT buf) {
1291 const size_t N = Lanes(d);
1292 constexpr size_t kSampleLanes = Constants::SampleLanes<T>();
1293 // Network must be large enough to sort two chunks.
1294 HWY_DASSERT(Constants::BaseCaseNumLanes<st.LanesPerKey()>(N) >= kSampleLanes);
1295
1296 BaseCase(d, st, buf, kSampleLanes, buf + kSampleLanes);
1297
1298 if (VQSORT_PRINT >= 2) {
1299 fprintf(stderr, "Samples:\n");
1300 for (size_t i = 0; i < kSampleLanes; i += N) {
1301 MaybePrintVector(d, "", Load(d, buf + i), 0, N);
1302 }
1303 }
1304}
1305
1306// ------------------------------ Pivot selection
1307
1308enum class PivotResult {
1309 kDone, // stop without partitioning (all equal, or two-value partition)
1310 kNormal, // partition and recurse left and right
1311 kIsFirst, // partition but skip left recursion
1312 kWasLast, // partition but skip right recursion
1313};
1314
1315HWY_INLINE const char* PivotResultString(PivotResult result) {
1316 switch (result) {
1317 case PivotResult::kDone:
1318 return "done";
1319 case PivotResult::kNormal:
1320 return "normal";
1321 case PivotResult::kIsFirst:
1322 return "first";
1323 case PivotResult::kWasLast:
1324 return "last";
1325 }
1326 return "unknown";
1327}
1328
1329// (Could vectorize, but only 0.2% of total time)
1330template <class Traits, typename T>
1331HWY_INLINE size_t PivotRank(Traits st, const T* HWY_RESTRICT samples) {
1332 constexpr size_t kSampleLanes = Constants::SampleLanes<T>();
1333 constexpr size_t N1 = st.LanesPerKey();
1334
1335 constexpr size_t kRankMid = kSampleLanes / 2;
1336 static_assert(kRankMid % N1 == 0, "Mid is not an aligned key");
1337
1338 // Find the previous value not equal to the median.
1339 size_t rank_prev = kRankMid - N1;
1340 for (; st.Equal1(samples + rank_prev, samples + kRankMid); rank_prev -= N1) {
1341 // All previous samples are equal to the median.
1342 if (rank_prev == 0) return 0;
1343 }
1344
1345 size_t rank_next = rank_prev + N1;
1346 for (; st.Equal1(samples + rank_next, samples + kRankMid); rank_next += N1) {
1347 // The median is also the largest sample. If it is also the largest key,
1348 // we'd end up with an empty right partition, so choose the previous key.
1349 if (rank_next == kSampleLanes - N1) return rank_prev;
1350 }
1351
1352 // If we choose the median as pivot, the ratio of keys ending in the left
1353 // partition will likely be rank_next/kSampleLanes (if the sample is
1354 // representative). This is because equal-to-pivot values also land in the
1355 // left - it's infeasible to do an in-place vectorized 3-way partition.
1356 // Check whether prev would lead to a more balanced partition.
1357 const size_t excess_if_median = rank_next - kRankMid;
1358 const size_t excess_if_prev = kRankMid - rank_prev;
1359 return excess_if_median < excess_if_prev ? kRankMid : rank_prev;
1360}
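// Editorial note (worked example, not in the original file): with 16
// single-lane samples sorted as 1 1 2 2 2 2 2 5 5 5 5 7 7 8 8 9, the median
// at kRankMid = 8 is 5. Scanning gives rank_prev = 6 (the last 2) and
// rank_next = 11 (the first 7). Because equal-to-pivot keys also go left,
// picking the median would send roughly rank_next/16 = 11/16 of the keys
// left; its excess over the midpoint (11 - 8 = 3) exceeds that of the
// previous distinct value (8 - 6 = 2), so PivotRank returns rank_prev = 6.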
1361
1362// Returns pivot chosen from `samples`. It will never be the largest key
1363// (thus the right partition will never be empty).
1364template <class D, class Traits, typename T>
1365HWY_INLINE Vec<D> ChoosePivotByRank(D d, Traits st,
1366 const T* HWY_RESTRICT samples) {
1367 const size_t pivot_rank = PivotRank(st, samples);
1368 const Vec<D> pivot = st.SetKey(d, samples + pivot_rank);
1369 if (VQSORT_PRINT >= 2) {
1370 fprintf(stderr, " Pivot rank %3zu\n", pivot_rank);
1371 HWY_ALIGN T pivot_lanes[MaxLanes(d)];
1372 Store(pivot, d, pivot_lanes);
1373 using KeyType = typename Traits::KeyType;
1374 KeyType key;
1375 CopyBytes<sizeof(KeyType)>(pivot_lanes, &key);
1376 PrintValue(key);
1377 }
1378 // Verify pivot is not equal to the last sample.
1379 constexpr size_t kSampleLanes = Constants::SampleLanes<T>();
1380 constexpr size_t N1 = st.LanesPerKey();
1381 const Vec<D> last = st.SetKey(d, samples + kSampleLanes - N1);
1382 const bool all_neq = AllTrue(d, st.NotEqualKeys(d, pivot, last));
1383 (void)all_neq;
1384 HWY_DASSERT(all_neq);
1385 return pivot;
1386}
1387
1388// Returns true if all keys equal `pivot`; otherwise returns false and sets
1389// `*first_mismatch` to the index of the first differing key.
1390template <class D, class Traits, typename T>
1391HWY_INLINE bool AllEqual(D d, Traits st, const Vec<D> pivot,
1392 const T* HWY_RESTRICT keys, size_t num,
1393 size_t* HWY_RESTRICT first_mismatch) {
1394 const size_t N = Lanes(d);
1395 // Ensures we can use overlapping loads for the tail; see HandleSpecialCases.
1396 HWY_DASSERT(num >= N);
1397 const Vec<D> zero = Zero(d);
1398
1399 // Vector-align keys + i.
1400 const size_t misalign =
1401 (reinterpret_cast<uintptr_t>(keys) / sizeof(T)) & (N - 1);
1402 HWY_DASSERT(misalign % st.LanesPerKey() == 0);
1403 const size_t consume = N - misalign;
1404 {
1405 const Vec<D> v = LoadU(d, keys);
1406 // Only check masked lanes; consider others to be equal.
1407 const Mask<D> diff = And(FirstN(d, consume), st.NotEqualKeys(d, v, pivot));
1408 if (HWY_UNLIKELY(!AllFalse(d, diff))) {
1409 const size_t lane = FindKnownFirstTrue(d, diff);
1410 *first_mismatch = lane;
1411 return false;
1412 }
1413 }
1414 size_t i = consume;
1415 HWY_DASSERT(((reinterpret_cast<uintptr_t>(keys + i) / sizeof(T)) & (N - 1)) ==
1416 0);
1417
1418 // Disable the OrXor optimization for floats because OrXor will not treat
1419 // subnormals the same as actual comparisons, leading to logic errors for
1420 // 2-value cases.
1421 if (!hwy::IsFloat<T>()) {
1422 // Sticky bits registering any difference between `keys` and the pivot.
1423 // We use vector XOR because it may be cheaper than comparisons, especially
1424 // for 128-bit. 2x unrolled for more ILP.
1425 Vec<D> diff0 = zero;
1426 Vec<D> diff1 = zero;
1427
1428 // We want to stop once a difference has been found, but without slowing
1429 // down the loop by comparing during each iteration. The compromise is to
1430 // compare after a 'group', which consists of kLoops times two vectors.
1431 constexpr size_t kLoops = 8;
1432 const size_t lanes_per_group = kLoops * 2 * N;
1433
1434 if (num >= lanes_per_group) {
1435 for (; i <= num - lanes_per_group; i += lanes_per_group) {
1437 for (size_t loop = 0; loop < kLoops; ++loop) {
1438 const Vec<D> v0 = Load(d, keys + i + loop * 2 * N);
1439 const Vec<D> v1 = Load(d, keys + i + loop * 2 * N + N);
1440 diff0 = OrXor(diff0, v0, pivot);
1441 diff1 = OrXor(diff1, v1, pivot);
1442 }
1443
1444 // If there was a difference in the entire group:
1445 if (HWY_UNLIKELY(!st.NoKeyDifference(d, Or(diff0, diff1)))) {
1446 // .. then loop until the first one, with termination guarantee.
1447 for (;; i += N) {
1448 const Vec<D> v = Load(d, keys + i);
1449 const Mask<D> diff = st.NotEqualKeys(d, v, pivot);
1450 if (HWY_UNLIKELY(!AllFalse(d, diff))) {
1451 const size_t lane = FindKnownFirstTrue(d, diff);
1452 *first_mismatch = i + lane;
1453 return false;
1454 }
1455 }
1456 }
1457 }
1458 }
1459 } // !hwy::IsFloat<T>()
1460
1461 // Whole vectors, no unrolling, compare directly
1462 for (; i <= num - N; i += N) {
1463 const Vec<D> v = Load(d, keys + i);
1464 const Mask<D> diff = st.NotEqualKeys(d, v, pivot);
1465 if (HWY_UNLIKELY(!AllFalse(d, diff))) {
1466 const size_t lane = FindKnownFirstTrue(d, diff);
1467 *first_mismatch = i + lane;
1468 return false;
1469 }
1470 }
1471 // Always re-check the last (unaligned) vector to reduce branching.
1472 i = num - N;
1473 const Vec<D> v = LoadU(d, keys + i);
1474 const Mask<D> diff = st.NotEqualKeys(d, v, pivot);
1475 if (HWY_UNLIKELY(!AllFalse(d, diff))) {
1476 const size_t lane = FindKnownFirstTrue(d, diff);
1477 *first_mismatch = i + lane;
1478 return false;
1479 }
1480
1481 if (VQSORT_PRINT >= 1) {
1482 fprintf(stderr, "All keys equal\n");
1483 }
1484 return true; // all equal
1485}
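// For intuition, a scalar analogue of the sticky-XOR accumulation above (a
// sketch assuming plain unsigned integer keys and a scalar `pivot_value`, not
// the actual vector code path):
//   uint64_t sticky = 0;
//   for (size_t j = 0; j < num; ++j) sticky |= keys[j] ^ pivot_value;
//   const bool all_equal = (sticky == 0);
// OR-ing each key XOR pivot leaves a nonzero bit as soon as any key differs,
// so one test per group of vectors suffices instead of a compare per vector.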
1486
1487// Called from 'two locations', but only one is active (IsKV is constexpr).
1488template <class D, class Traits, typename T>
1489HWY_INLINE bool ExistsAnyBefore(D d, Traits st, const T* HWY_RESTRICT keys,
1490 size_t num, const Vec<D> pivot) {
1491 const size_t N = Lanes(d);
1492 HWY_DASSERT(num >= N); // See HandleSpecialCases
1493
1494 if (VQSORT_PRINT >= 2) {
1495 fprintf(stderr, "Scanning for before\n");
1496 }
1497
1498 size_t i = 0;
1499
1500 constexpr size_t kLoops = 16;
1501 const size_t lanes_per_group = kLoops * N;
1502
1503 Vec<D> first = pivot;
1504
1505 // Whole group, unrolled
1506 if (num >= lanes_per_group) {
1507 for (; i <= num - lanes_per_group; i += lanes_per_group) {
1509 for (size_t loop = 0; loop < kLoops; ++loop) {
1510 const Vec<D> curr = LoadU(d, keys + i + loop * N);
1511 first = st.First(d, first, curr);
1512 }
1513
1514 if (HWY_UNLIKELY(!AllFalse(d, st.Compare(d, first, pivot)))) {
1515 if (VQSORT_PRINT >= 2) {
1516 fprintf(stderr, "Stopped scanning at end of group %zu\n",
1517 i + lanes_per_group);
1518 }
1519 return true;
1520 }
1521 }
1522 }
1523 // Whole vectors, no unrolling
1524 for (; i <= num - N; i += N) {
1525 const Vec<D> curr = LoadU(d, keys + i);
1526 if (HWY_UNLIKELY(!AllFalse(d, st.Compare(d, curr, pivot)))) {
1527 if (VQSORT_PRINT >= 2) {
1528 fprintf(stderr, "Stopped scanning at %zu\n", i);
1529 }
1530 return true;
1531 }
1532 }
1533 // If there are remainders, re-check the last whole vector.
1534 if (HWY_LIKELY(i != num)) {
1535 const Vec<D> curr = LoadU(d, keys + num - N);
1536 if (HWY_UNLIKELY(!AllFalse(d, st.Compare(d, curr, pivot)))) {
1537 if (VQSORT_PRINT >= 2) {
1538 fprintf(stderr, "Stopped scanning at last %zu\n", num - N);
1539 }
1540 return true;
1541 }
1542 }
1543
1544 return false; // pivot is the first
1545}
1546
1547// Called from 'two locations', but only one is active (IsKV is constexpr).
1548template <class D, class Traits, typename T>
1549HWY_INLINE bool ExistsAnyAfter(D d, Traits st, const T* HWY_RESTRICT keys,
1550 size_t num, const Vec<D> pivot) {
1551 const size_t N = Lanes(d);
1552 HWY_DASSERT(num >= N); // See HandleSpecialCases
1553
1554 if (VQSORT_PRINT >= 2) {
1555 fprintf(stderr, "Scanning for after\n");
1556 }
1557
1558 size_t i = 0;
1559
1560 constexpr size_t kLoops = 16;
1561 const size_t lanes_per_group = kLoops * N;
1562
1563 Vec<D> last = pivot;
1564
1565 // Whole group, unrolled
1566 if (num >= lanes_per_group) {
1567 for (; i + lanes_per_group <= num; i += lanes_per_group) {
1569 for (size_t loop = 0; loop < kLoops; ++loop) {
1570 const Vec<D> curr = LoadU(d, keys + i + loop * N);
1571 last = st.Last(d, last, curr);
1572 }
1573
1574 if (HWY_UNLIKELY(!AllFalse(d, st.Compare(d, pivot, last)))) {
1575 if (VQSORT_PRINT >= 2) {
1576 fprintf(stderr, "Stopped scanning at end of group %zu\n",
1577 i + lanes_per_group);
1578 }
1579 return true;
1580 }
1581 }
1582 }
1583 // Whole vectors, no unrolling
1584 for (; i <= num - N; i += N) {
1585 const Vec<D> curr = LoadU(d, keys + i);
1586 if (HWY_UNLIKELY(!AllFalse(d, st.Compare(d, pivot, curr)))) {
1587 if (VQSORT_PRINT >= 2) {
1588 fprintf(stderr, "Stopped scanning at %zu\n", i);
1589 }
1590 return true;
1591 }
1592 }
1593 // If there are remainders, re-check the last whole vector.
1594 if (HWY_LIKELY(i != num)) {
1595 const Vec<D> curr = LoadU(d, keys + num - N);
1596 if (HWY_UNLIKELY(!AllFalse(d, st.Compare(d, pivot, curr)))) {
1597 if (VQSORT_PRINT >= 2) {
1598 fprintf(stderr, "Stopped scanning at last %zu\n", num - N);
1599 }
1600 return true;
1601 }
1602 }
1603
1604 return false; // pivot is the last
1605}
1606
1607// Returns pivot chosen from `keys[0, num)`. It will never be the largest key
1608// (thus the right partition will never be empty).
1609template <class D, class Traits, typename T>
1610HWY_INLINE Vec<D> ChoosePivotForEqualSamples(D d, Traits st,
1611 T* HWY_RESTRICT keys, size_t num,
1612 T* HWY_RESTRICT samples,
1613 Vec<D> second, Vec<D> third,
1614 PivotResult& result) {
1615 const Vec<D> pivot = st.SetKey(d, samples); // the single unique sample
1616
1617 // Early out for mostly-0 arrays, where pivot is often FirstValue.
1618 if (HWY_UNLIKELY(AllTrue(d, st.EqualKeys(d, pivot, st.FirstValue(d))))) {
1619 result = PivotResult::kIsFirst;
1620 if (VQSORT_PRINT >= 2) {
1621 fprintf(stderr, "Pivot equals first possible value\n");
1622 }
1623 return pivot;
1624 }
1625 if (HWY_UNLIKELY(AllTrue(d, st.EqualKeys(d, pivot, st.LastValue(d))))) {
1626 if (VQSORT_PRINT >= 2) {
1627 fprintf(stderr, "Pivot equals last possible value\n");
1628 }
1629 result = PivotResult::kWasLast;
1630 return st.PrevValue(d, pivot);
1631 }
1632
1633 // If key-value, we didn't run PartitionIfTwo* and thus `third` is unknown and
1634 // cannot be used.
1635 if (st.IsKV()) {
1636 // If true, pivot is either middle or last.
1637 const bool before = !AllFalse(d, st.Compare(d, second, pivot));
1638 if (HWY_UNLIKELY(before)) {
1639 // Not last, so middle.
1640 if (HWY_UNLIKELY(ExistsAnyAfter(d, st, keys, num, pivot))) {
1641 result = PivotResult::kNormal;
1642 return pivot;
1643 }
1644
1645 // We didn't find anything after pivot, so it is the last. Because keys
1646 // equal to the pivot go to the left partition, the right partition would
1647 // be empty and Partition will not have changed anything. Instead use the
1648 // previous value in sort order, which is not necessarily an actual key.
1649 result = PivotResult::kWasLast;
1650 return st.PrevValue(d, pivot);
1651 }
1652
1653 // Otherwise, pivot is first or middle. Rule out it being first:
1654 if (HWY_UNLIKELY(ExistsAnyBefore(d, st, keys, num, pivot))) {
1655 result = PivotResult::kNormal;
1656 return pivot;
1657 }
1658 // It is first: fall through to shared code below.
1659 } else {
1660 // Check if pivot is between two known values. If so, it is not the first
1661 // nor the last and we can avoid scanning.
1662 st.Sort2(d, second, third);
1663 HWY_DASSERT(AllTrue(d, st.Compare(d, second, third)));
1664 const bool before = !AllFalse(d, st.Compare(d, second, pivot));
1665 const bool after = !AllFalse(d, st.Compare(d, pivot, third));
1666 // Only reached if there are three keys, which means pivot is either first,
1667 // last, or in between. Thus there is another key that comes before or
1668 // after.
1669 HWY_DASSERT(before || after);
1670 if (HWY_UNLIKELY(before)) {
1671 // Neither first nor last.
1672 if (HWY_UNLIKELY(after || ExistsAnyAfter(d, st, keys, num, pivot))) {
1673 result = PivotResult::kNormal;
1674 return pivot;
1675 }
1676
1677 // We didn't find anything after pivot, so it is the last. Because keys
1678 // equal to the pivot go to the left partition, the right partition would
1679 // be empty and Partition will not have changed anything. Instead use the
1680 // previous value in sort order, which is not necessarily an actual key.
1681 result = PivotResult::kWasLast;
1682 return st.PrevValue(d, pivot);
1683 }
1684
1685 // A key after exists; if one also exists before, the pivot is in the middle.
1686 if (HWY_UNLIKELY(ExistsAnyBefore(d, st, keys, num, pivot))) {
1687 result = PivotResult::kNormal;
1688 return pivot;
1689 }
1690 }
1691
1692 // Pivot is first. We could consider a special partition mode that only
1693 // reads from and writes to the right side, and later fills in the left
1694 // side, which we know is equal to the pivot. However, that leads to more
1695 // cache misses if the array is large, and doesn't save much, hence is a
1696 // net loss.
1697 result = PivotResult::kIsFirst;
1698 return pivot;
1699}
1700
1701// ------------------------------ Quicksort recursion
1702
1703enum class RecurseMode {
1704 kSort, // Sort mode.
1705 kSelect, // Select mode.
1706 // The element pointed at by nth is changed to whatever element
1707 // would occur in that position if [first, last) were sorted. All of
1708 // the elements before this new nth element are less than or equal
1709 // to the elements after the new nth element.
1710 kLooseSelect, // Loose select mode.
1711 // The first n elements will contain the n smallest elements in
1712 // unspecified order
1713};
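// Roughly: kSort corresponds to std::sort, kSelect to std::nth_element, and
// kLooseSelect to the weaker guarantee where only the membership of the first
// n elements is specified, not their order.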
1714
1715template <class D, class Traits, typename T>
1716HWY_NOINLINE void PrintMinMax(D d, Traits st, const T* HWY_RESTRICT keys,
1717 size_t num, T* HWY_RESTRICT buf) {
1718 if (VQSORT_PRINT >= 2) {
1719 const size_t N = Lanes(d);
1720 if (num < N) return;
1721
1722 Vec<D> first = st.LastValue(d);
1723 Vec<D> last = st.FirstValue(d);
1724
1725 size_t i = 0;
1726 for (; i <= num - N; i += N) {
1727 const Vec<D> v = LoadU(d, keys + i);
1728 first = st.First(d, v, first);
1729 last = st.Last(d, v, last);
1730 }
1731 if (HWY_LIKELY(i != num)) {
1732 HWY_DASSERT(num >= N); // See HandleSpecialCases
1733 const Vec<D> v = LoadU(d, keys + num - N);
1734 first = st.First(d, v, first);
1735 last = st.Last(d, v, last);
1736 }
1737
1738 first = st.FirstOfLanes(d, first, buf);
1739 last = st.LastOfLanes(d, last, buf);
1740 MaybePrintVector(d, "first", first, 0, st.LanesPerKey());
1741 MaybePrintVector(d, "last", last, 0, st.LanesPerKey());
1742 }
1743}
1744
1745template <RecurseMode mode, class D, class Traits, typename T>
1746HWY_NOINLINE void Recurse(D d, Traits st, T* HWY_RESTRICT keys,
1747 const size_t num, T* HWY_RESTRICT buf,
1748 uint64_t* HWY_RESTRICT state,
1749 const size_t remaining_levels, const size_t k = 0) {
1750 HWY_DASSERT(num != 0);
1751
1752 const size_t N = Lanes(d);
1753 constexpr size_t kLPK = st.LanesPerKey();
1754 if (HWY_UNLIKELY(num <= Constants::BaseCaseNumLanes<kLPK>(N))) {
1755 BaseCase(d, st, keys, num, buf);
1756 return;
1757 }
1758
1759 // Move after BaseCase so we skip printing for small subarrays.
1760 if (VQSORT_PRINT >= 1) {
1761 fprintf(stderr, "\n\n=== Recurse depth=%zu len=%zu\n", remaining_levels,
1762 num);
1763 PrintMinMax(d, st, keys, num, buf);
1764 }
1765
1766 DrawSamples(d, st, keys, num, buf, state);
1767
1768 Vec<D> pivot;
1769 PivotResult result = PivotResult::kNormal;
1770 if (HWY_UNLIKELY(UnsortedSampleEqual(d, st, buf))) {
1771 pivot = st.SetKey(d, buf);
1772 size_t idx_second = 0;
1773 if (HWY_UNLIKELY(AllEqual(d, st, pivot, keys, num, &idx_second))) {
1774 return;
1775 }
1776 HWY_DASSERT(idx_second % st.LanesPerKey() == 0);
1777 // Must capture the value before PartitionIfTwoKeys may overwrite it.
1778 const Vec<D> second = st.SetKey(d, keys + idx_second);
1779 MaybePrintVector(d, "pivot", pivot, 0, st.LanesPerKey());
1780 MaybePrintVector(d, "second", second, 0, st.LanesPerKey());
1781
1782 Vec<D> third;
1783 // Not supported for key-value types because two 'keys' may be equivalent
1784 // but not interchangeable (their values may differ).
1785 if (HWY_UNLIKELY(!st.IsKV() &&
1786 PartitionIfTwoKeys(d, st, pivot, keys, num, idx_second,
1787 second, third, buf))) {
1788 return; // Done, skip recursion because each side has all-equal keys.
1789 }
1790
1791 // We can no longer start scanning from idx_second because
1792 // PartitionIfTwoKeys may have reordered keys.
1793 pivot = ChoosePivotForEqualSamples(d, st, keys, num, buf, second, third,
1794 result);
1795 // If kNormal, `pivot` is very common but not the first/last. It is
1796 // tempting to do a 3-way partition (to avoid moving the =pivot keys a
1797 // second time), but that is a net loss due to the extra comparisons.
1798 } else {
1799 SortSamples(d, st, buf);
1800
1801 // Not supported for key-value types because two 'keys' may be equivalent
1802 // but not interchangeable (their values may differ).
1803 if (HWY_UNLIKELY(!st.IsKV() &&
1804 PartitionIfTwoSamples(d, st, keys, num, buf))) {
1805 return;
1806 }
1807
1808 pivot = ChoosePivotByRank(d, st, buf);
1809 }
1810
1811 // Too many recursions. This is unlikely to happen because we select pivots
1812 // from large (though still O(1)) samples.
1813 if (HWY_UNLIKELY(remaining_levels == 0)) {
1814 if (VQSORT_PRINT >= 1) {
1815 fprintf(stderr, "HeapSort reached, size=%zu\n", num);
1816 }
1817 HeapSort(st, keys, num); // Slow but N*logN.
1818 return;
1819 }
1820
1821 const size_t bound = Partition(d, st, keys, num, pivot, buf);
1822 if (VQSORT_PRINT >= 2) {
1823 fprintf(stderr, "bound %zu num %zu result %s\n", bound, num,
1824 PivotResultString(result));
1825 }
1826 // The left partition is not empty because the pivot is usually one of the
1827 // keys. Exception: if kWasLast, we set pivot to PrevValue(pivot), but we
1828 // still have at least one value <= pivot because AllEqual ruled out the case
1829 // of only one unique value. Note that for floating-point, PrevValue can
1830 // return the same value (for -inf inputs), but that would just mean the
1831 // pivot is again one of the keys.
1832 HWY_DASSERT(bound != 0);
1833 // ChoosePivot* ensure pivot != last, so the right partition is never empty
1834 // except in the rare case of the pivot matching the last-in-sort-order value,
1835 // which implies we anyway skip the right partition due to kWasLast.
1836 HWY_DASSERT(bound != num || result == PivotResult::kWasLast);
1837
1838 HWY_IF_CONSTEXPR(mode == RecurseMode::kSelect) {
1839 if (HWY_LIKELY(result != PivotResult::kIsFirst) && k < bound) {
1840 Recurse<RecurseMode::kSelect>(d, st, keys, bound, buf, state,
1841 remaining_levels - 1, k);
1842 } else if (HWY_LIKELY(result != PivotResult::kWasLast) && k >= bound) {
1843 Recurse<RecurseMode::kSelect>(d, st, keys + bound, num - bound, buf,
1844 state, remaining_levels - 1, k - bound);
1845 }
1846 }
1847 HWY_IF_CONSTEXPR(mode == RecurseMode::kSort) {
1848 if (HWY_LIKELY(result != PivotResult::kIsFirst)) {
1849 Recurse<RecurseMode::kSort>(d, st, keys, bound, buf, state,
1850 remaining_levels - 1);
1851 }
1852 if (HWY_LIKELY(result != PivotResult::kWasLast)) {
1853 Recurse<RecurseMode::kSort>(d, st, keys + bound, num - bound, buf, state,
1854 remaining_levels - 1);
1855 }
1856 }
1857}
1858
1859// Returns true if sorting is finished.
1860template <class D, class Traits, typename T>
1861HWY_INLINE bool HandleSpecialCases(D d, Traits st, T* HWY_RESTRICT keys,
1862 size_t num, T* HWY_RESTRICT buf) {
1863 const size_t N = Lanes(d);
1864 constexpr size_t kLPK = st.LanesPerKey();
1865 const size_t base_case_num = Constants::BaseCaseNumLanes<kLPK>(N);
1866
1867 // Recurse will also check this, but doing so here first avoids setting up
1868 // the random generator state.
1869 if (HWY_UNLIKELY(num <= base_case_num)) {
1870 if (VQSORT_PRINT >= 1) {
1871 fprintf(stderr, "Special-casing small, %d lanes\n",
1872 static_cast<int>(num));
1873 }
1874 BaseCase(d, st, keys, num, buf);
1875 return true;
1876 }
1877
1878 // 128-bit keys require vectors with at least two u64 lanes, which is always
1879 // the case unless `d` requests partial vectors (e.g. fraction = 1/2) AND the
1880 // hardware vector width is less than 128bit / fraction.
1881 const bool partial_128 = !IsFull(d) && N < 2 && st.Is128();
1882 // Partition assumes its input is at least two vectors. If vectors are huge,
1883 // base_case_num may actually be smaller. If so (which is only possible on
1884 // RVV), pass a capped or partial d (LMUL < 1). Use HWY_MAX_BYTES instead of
1885 // HWY_LANES to account for the largest possible LMUL.
1886 constexpr bool kPotentiallyHuge =
1888 const bool huge_vec = kPotentiallyHuge && (2 * N > base_case_num);
1889 if (partial_128 || huge_vec) {
1890 if (VQSORT_PRINT >= 1) {
1891 fprintf(stderr, "WARNING: using slow HeapSort: partial %d huge %d\n",
1892 partial_128, huge_vec);
1893 }
1894 HeapSort(st, keys, num);
1895 return true;
1896 }
1897
1898 // We could also check for already sorted/reverse/equal, but that's probably
1899 // counterproductive if vqsort is used as a base case.
1900
1901 return false; // not finished sorting
1902}
1903
1904#endif // VQSORT_ENABLED
1905
1906template <class D, class Traits, typename T, HWY_IF_FLOAT(T)>
1907HWY_INLINE size_t CountAndReplaceNaN(D d, Traits st, T* HWY_RESTRICT keys,
1908 size_t num) {
1909 const size_t N = Lanes(d);
1910 // Will be sorted to the back of the array.
1911 const Vec<D> sentinel = st.LastValue(d);
1912 size_t num_nan = 0;
1913 size_t i = 0;
1914 if (num >= N) {
1915 for (; i <= num - N; i += N) {
1916 const Mask<D> is_nan = IsNaN(LoadU(d, keys + i));
1917 BlendedStore(sentinel, is_nan, d, keys + i);
1918 num_nan += CountTrue(d, is_nan);
1919 }
1920 }
1921
1922 const size_t remaining = num - i;
1923 HWY_DASSERT(remaining < N);
1924 const Vec<D> v = LoadN(d, keys + i, remaining);
1925 const Mask<D> is_nan = IsNaN(v);
1926 StoreN(IfThenElse(is_nan, sentinel, v), d, keys + i, remaining);
1927 num_nan += CountTrue(d, is_nan);
1928 return num_nan;
1929}
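// Note: together with the Fill(d, GetLane(NaN(d)), ...) calls at the end of
// Sort/Select/PartialSort below, this implements the documented NaN handling:
// NaNs are first replaced by the last-in-sort-order sentinel so they end up at
// the back, and those trailing slots are then rewritten as the canonical
// NaN(d).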
1930
1931// IsNaN is not implemented for non-float, so skip it.
1932template <class D, class Traits, typename T, HWY_IF_NOT_FLOAT(T)>
1933HWY_INLINE size_t CountAndReplaceNaN(D, Traits, T* HWY_RESTRICT, size_t) {
1934 return 0;
1935}
1936
1937} // namespace detail
1938
1939// Old interface with user-specified buffer, retained for compatibility. Called
1940// by the newer overload below. `buf` must be vector-aligned and hold at least
1941// SortConstants::BufBytes(HWY_MAX_BYTES, st.LanesPerKey()).
1942template <class D, class Traits, typename T>
1943void Sort(D d, Traits st, T* HWY_RESTRICT keys, const size_t num,
1944 T* HWY_RESTRICT buf) {
1945 if (VQSORT_PRINT >= 1) {
1946 fprintf(stderr,
1947 "=============== Sort num %zu is128 %d isKV %d vec bytes %d\n", num,
1948 st.Is128(), st.IsKV(), static_cast<int>(sizeof(T) * Lanes(d)));
1949 }
1950
1951#if HWY_MAX_BYTES > 64
1952 // sorting_networks-inl and traits assume no more than 512 bit vectors.
1953 if (HWY_UNLIKELY(Lanes(d) > 64 / sizeof(T))) {
1954 return Sort(CappedTag<T, 64 / sizeof(T)>(), st, keys, num, buf);
1955 }
1956#endif // HWY_MAX_BYTES > 64
1957
1958 const size_t num_nan = detail::CountAndReplaceNaN(d, st, keys, num);
1959
1960#if VQSORT_ENABLED || HWY_IDE
1961 if (!detail::HandleSpecialCases(d, st, keys, num, buf)) {
1962 uint64_t* HWY_RESTRICT state = hwy::detail::GetGeneratorStateStatic();
1963 // Introspection: switch to worst-case N*logN heapsort after this many.
1964 // Should never be reached, so computing log2 exactly does not help.
1965 const size_t max_levels = 50;
1966 detail::Recurse<detail::RecurseMode::kSort>(d, st, keys, num, buf, state,
1967 max_levels);
1968 }
1969#else // !VQSORT_ENABLED
1970 (void)d;
1971 (void)buf;
1972 if (VQSORT_PRINT >= 1) {
1973 fprintf(stderr, "WARNING: using slow HeapSort because vqsort disabled\n");
1974 }
1975 detail::HeapSort(st, keys, num);
1976#endif // VQSORT_ENABLED
1977
1978 if (num_nan != 0) {
1979 Fill(d, GetLane(NaN(d)), num_nan, keys + num - num_nan);
1980 }
1981}
1982
1983template <class D, class Traits, typename T>
1984void Select(D d, Traits st, T* HWY_RESTRICT keys, const size_t num,
1985 const size_t k, T* HWY_RESTRICT buf) {
1986 if (VQSORT_PRINT >= 1) {
1987 fprintf(stderr, "=============== Select num=%zu, vec bytes=%d\n", num,
1988 static_cast<int>(sizeof(T) * Lanes(d)));
1989 }
1990
1991#if HWY_MAX_BYTES > 64
1992 // sorting_networks-inl and traits assume no more than 512 bit vectors.
1993 if (HWY_UNLIKELY(Lanes(d) > 64 / sizeof(T))) {
1994 return Select(CappedTag<T, 64 / sizeof(T)>(), st, keys, num, k, buf);
1995 }
1996#endif // HWY_MAX_BYTES > 64
1997
1998 const size_t num_nan = detail::CountAndReplaceNaN(d, st, keys, num);
1999
2000#if VQSORT_ENABLED || HWY_IDE
2001 if (!detail::HandleSpecialCases(d, st, keys, num, buf)) { // TODO
2002 uint64_t* HWY_RESTRICT state = hwy::detail::GetGeneratorStateStatic();
2003 // Introspection: switch to worst-case N*logN heapsort after this many.
2004 // Should never be reached, so computing log2 exactly does not help.
2005 const size_t max_levels = 50;
2006 detail::Recurse<detail::RecurseMode::kSelect>(d, st, keys, num, buf, state,
2007 max_levels, k);
2008 }
2009#else // !VQSORT_ENABLED
2010 (void)d;
2011 (void)buf;
2012 if (VQSORT_PRINT >= 1) {
2013 fprintf(stderr, "WARNING: using slow HeapSort because vqsort disabled\n");
2014 }
2015 detail::HeapSelect(st, keys, num, k);
2016#endif // VQSORT_ENABLED
2017
2018 if (num_nan != 0) {
2019 Fill(d, GetLane(NaN(d)), num_nan, keys + num - num_nan);
2020 }
2021}
2022
2023template <class D, class Traits, typename T>
2024void PartialSort(D d, Traits st, T* HWY_RESTRICT keys, size_t num, size_t k,
2025 T* HWY_RESTRICT buf) {
2026 if (VQSORT_PRINT >= 1) {
2027 fprintf(stderr, "=============== PartialSort num=%zu, vec bytes=%d\n", num,
2028 static_cast<int>(sizeof(T) * Lanes(d)));
2029 }
2030
2031#if HWY_MAX_BYTES > 64
2032 // sorting_networks-inl and traits assume no more than 512 bit vectors.
2033 if (HWY_UNLIKELY(Lanes(d) > 64 / sizeof(T))) {
2034 return PartialSort(CappedTag<T, 64 / sizeof(T)>(), st, keys, num, k, buf);
2035 }
2036#endif // HWY_MAX_BYTES > 64
2037
2038 const size_t num_nan = detail::CountAndReplaceNaN(d, st, keys, num);
2039
2040#if VQSORT_ENABLED || HWY_IDE
2041 if (!detail::HandleSpecialCases(d, st, keys, num, buf)) { // TODO
2042 uint64_t* HWY_RESTRICT state = hwy::detail::GetGeneratorStateStatic();
2043 // Introspection: switch to worst-case N*logN heapsort after this many.
2044 // Should never be reached, so computing log2 exactly does not help.
2045 const size_t max_levels = 50;
2046 // TODO: optimize to use kLooseSelect
2047 detail::Recurse<detail::RecurseMode::kSelect>(d, st, keys, num, buf, state,
2048 max_levels, k);
2049 detail::Recurse<detail::RecurseMode::kSort>(d, st, keys, k, buf, state,
2050 max_levels);
2051 }
2052#else // !VQSORT_ENABLED
2053 (void)d;
2054 (void)buf;
2055 if (VQSORT_PRINT >= 1) {
2056 fprintf(stderr, "WARNING: using slow HeapSort because vqsort disabled\n");
2057 }
2058 detail::HeapPartialSort(st, keys, num, k);
2059#endif // VQSORT_ENABLED
2060
2061 if (num_nan != 0) {
2062 Fill(d, GetLane(NaN(d)), num_nan, keys + num - num_nan);
2063 }
2064}
2065
2066// Sorts `keys[0..num-1]` according to the order defined by `st.Compare`.
2067// In-place i.e. O(1) additional storage. Worst-case N*logN comparisons.
2068// Non-stable (order of equal keys may change), except for the common case where
2069// the upper bits of T are the key, and the lower bits are a sequential or at
2070// least unique ID. Any NaN will be moved to the back of the array and replaced
2071// with the canonical NaN(d).
2072// There is no upper limit on `num`, but note that pivots may be chosen by
2073// sampling only from the first 256 GiB.
2074//
2075// `d` is typically SortTag<T> (chooses between full and partial vectors).
2076// `st` is SharedTraits<Traits*<Order*>>. This abstraction layer bridges
2077// differences in sort order and single-lane vs 128-bit keys.
2078template <class D, class Traits, typename T>
2079HWY_API void Sort(D d, Traits st, T* HWY_RESTRICT keys, const size_t num) {
2080 constexpr size_t kLPK = st.LanesPerKey();
2081 HWY_ALIGN T buf[SortConstants::BufBytes<T, kLPK>(HWY_MAX_BYTES) / sizeof(T)];
2082 return Sort(d, st, keys, num, buf);
2083}
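// Minimal usage sketch for the wrapper above, assuming the detail::SharedTraits
// wrapper and the TraitsLane/OrderAscending types referenced by KeyAdapter
// below (their definitions live in the traits headers, not in this file):
//   float keys[1000] = { /* ... */ };
//   const SortTag<float> d;
//   detail::SharedTraits<detail::TraitsLane<detail::OrderAscending<float>>> st;
//   Sort(d, st, keys, 1000);  // ascending; NaNs end up at the back as NaN(d)
// Most callers can instead use VQSortStatic below, or VQSort for dynamic
// dispatch, which choose these types automatically.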
2084
2085// Rearranges elements such that the range [0, k) contains the k smallest
2086// elements of [0, num) in sorted order, as defined by `st.Compare`.
2087template <class D, class Traits, typename T>
2088HWY_API void PartialSort(D d, Traits st, T* HWY_RESTRICT keys, const size_t num,
2089 const size_t k) {
2090 HWY_ASSERT(k < num);
2091 constexpr size_t kLPK = st.LanesPerKey();
2092 HWY_ALIGN T buf[SortConstants::BufBytes<T, kLPK>(HWY_MAX_BYTES) / sizeof(T)];
2093 PartialSort(d, st, keys, num, k, buf);
2094}
2095
2096// Reorders `keys[0..num-1]` such that `keys[k]` is the element that would be
2097// at index k if `keys` were sorted by `st.Compare`, and every element before
2098// it precedes or equals `keys[k]` in that order. Otherwise as for Sort above.
2099template <class D, class Traits, typename T>
2100HWY_API void Select(D d, Traits st, T* HWY_RESTRICT keys, const size_t num,
2101 const size_t k) {
2102 HWY_ASSERT(k < num);
2103 constexpr size_t kLPK = st.LanesPerKey();
2104 HWY_ALIGN T buf[SortConstants::BufBytes<T, kLPK>(HWY_MAX_BYTES) / sizeof(T)];
2105 Select(d, st, keys, num, k, buf);
2106}
2107
2108#if VQSORT_ENABLED
2109// Adapter from VQSort[Static] to SortTag and Traits*/Order*.
2110namespace detail {
2111
2112// Primary template for built-in key types
2113template <typename T>
2114struct KeyAdapter {
2115 using Ascending = OrderAscending<T>;
2116 using Descending = OrderDescending<T>;
2117
2118 template <class Order>
2119 using Traits = TraitsLane<Order>;
2120};
2121
2122template <>
2123struct KeyAdapter<hwy::uint128_t> {
2124 using Ascending = OrderAscending128;
2125 using Descending = OrderDescending128;
2126
2127 template <class Order>
2128 using Traits = Traits128<Order>;
2129};
2130
2131template <>
2132struct KeyAdapter<hwy::K64V64> {
2133 using Ascending = OrderAscendingKV128;
2134 using Descending = OrderDescendingKV128;
2135
2136 template <class Order>
2137 using Traits = Traits128<Order>;
2138};
2139
2140template <>
2141struct KeyAdapter<hwy::K32V32> {
2142 using Ascending = OrderAscendingKV64;
2143 using Descending = OrderDescendingKV64;
2144
2145 template <class Order>
2146 using Traits = TraitsLane<Order>;
2147};
2148
2149} // namespace detail
2150#endif // VQSORT_ENABLED
2151
2152// Simpler interface matching VQSort(), but without dynamic dispatch. Uses the
2153// instructions available in the current target (HWY_NAMESPACE). Supported key
2154// types: 16-64 bit unsigned/signed/floating-point (but float64 only #if
2155// HWY_HAVE_FLOAT64), uint128_t, K64V64, K32V32.
2156template <typename T>
2157void VQSortStatic(T* HWY_RESTRICT keys, const size_t num, SortAscending) {
2158#if VQSORT_ENABLED
2159 using Adapter = detail::KeyAdapter<T>;
2160 using Order = typename Adapter::Ascending;
2161 const detail::SharedTraits<typename Adapter::template Traits<Order>> st;
2162 using LaneType = typename decltype(st)::LaneType;
2163 const SortTag<LaneType> d;
2164 Sort(d, st, reinterpret_cast<LaneType*>(keys), num * st.LanesPerKey());
2165#else
2166 (void)keys;
2167 (void)num;
2168 HWY_ASSERT(0);
2169#endif // VQSORT_ENABLED
2170}
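// Example call (SortAscending is the tag type from hwy/contrib/sort/order.h):
//   float keys[1000] = { /* ... */ };
//   VQSortStatic(keys, 1000, SortAscending());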
2171
2172template <typename T>
2173void VQSortStatic(T* HWY_RESTRICT keys, const size_t num, SortDescending) {
2174#if VQSORT_ENABLED
2175 using Adapter = detail::KeyAdapter<T>;
2176 using Order = typename Adapter::Descending;
2177 const detail::SharedTraits<typename Adapter::template Traits<Order>> st;
2178 using LaneType = typename decltype(st)::LaneType;
2179 const SortTag<LaneType> d;
2180 Sort(d, st, reinterpret_cast<LaneType*>(keys), num * st.LanesPerKey());
2181#else
2182 (void)keys;
2183 (void)num;
2184 HWY_ASSERT(0);
2185#endif // VQSORT_ENABLED
2186}
2187
2188template <typename T>
2189void VQPartialSortStatic(T* HWY_RESTRICT keys, const size_t num, const size_t k,
2190 SortAscending) {
2191#if VQSORT_ENABLED
2192 using Adapter = detail::KeyAdapter<T>;
2193 using Order = typename Adapter::Ascending;
2194 const detail::SharedTraits<typename Adapter::template Traits<Order>> st;
2195 using LaneType = typename decltype(st)::LaneType;
2196 const SortTag<LaneType> d;
2197 PartialSort(d, st, reinterpret_cast<LaneType*>(keys), num * st.LanesPerKey(),
2198 k);
2199#else
2200 (void)keys;
2201 (void)num;
2202 HWY_ASSERT(0);
2203#endif // VQSORT_ENABLED
2204}
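// Example call: keep the 10 smallest keys, sorted, at the front of the array
// (hypothetical sizes):
//   VQPartialSortStatic(keys, 1000, 10, SortAscending());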
2205
2206template <typename T>
2207void VQPartialSortStatic(T* HWY_RESTRICT keys, const size_t num, const size_t k,
2208 SortDescending) {
2209#if VQSORT_ENABLED
2210 using Adapter = detail::KeyAdapter<T>;
2211 using Order = typename Adapter::Descending;
2212 const detail::SharedTraits<typename Adapter::template Traits<Order>> st;
2213 using LaneType = typename decltype(st)::LaneType;
2214 const SortTag<LaneType> d;
2215 PartialSort(d, st, reinterpret_cast<LaneType*>(keys), num * st.LanesPerKey(),
2216 k);
2217#else
2218 (void)keys;
2219 (void)num;
2220 HWY_ASSERT(0);
2221#endif // VQSORT_ENABLED
2222}
2223
2224template <typename T>
2225void VQSelectStatic(T* HWY_RESTRICT keys, const size_t num, const size_t k,
2226 SortAscending) {
2227#if VQSORT_ENABLED
2228 using Adapter = detail::KeyAdapter<T>;
2229 using Order = typename Adapter::Ascending;
2230 const detail::SharedTraits<typename Adapter::template Traits<Order>> st;
2231 using LaneType = typename decltype(st)::LaneType;
2232 const SortTag<LaneType> d;
2233 Select(d, st, reinterpret_cast<LaneType*>(keys), num * st.LanesPerKey(), k);
2234#else
2235 (void)keys;
2236 (void)num;
2237 HWY_ASSERT(0);
2238#endif // VQSORT_ENABLED
2239}
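// Example call: move the median into keys[num / 2], similar to
// std::nth_element:
//   VQSelectStatic(keys, num, num / 2, SortAscending());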
2240
2241template <typename T>
2242void VQSelectStatic(T* HWY_RESTRICT keys, const size_t num, const size_t k,
2243 SortDescending) {
2244#if VQSORT_ENABLED
2245 using Adapter = detail::KeyAdapter<T>;
2246 using Order = typename Adapter::Descending;
2247 const detail::SharedTraits<typename Adapter::template Traits<Order>> st;
2248 using LaneType = typename decltype(st)::LaneType;
2249 const SortTag<LaneType> d;
2250 Select(d, st, reinterpret_cast<LaneType*>(keys), num * st.LanesPerKey(), k);
2251#else
2252 (void)keys;
2253 (void)num;
2254 HWY_ASSERT(0);
2255#endif // VQSORT_ENABLED
2256}
2257
2258// NOLINTNEXTLINE(google-readability-namespace-comments)
2259} // namespace HWY_NAMESPACE
2260} // namespace hwy
2261HWY_AFTER_NAMESPACE();
2262
2263#endif // HIGHWAY_HWY_CONTRIB_SORT_VQSORT_TOGGLE