algo-inl.h
// Copyright 2021 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Normal include guard for target-independent parts
#ifndef HIGHWAY_HWY_CONTRIB_SORT_ALGO_INL_H_
#define HIGHWAY_HWY_CONTRIB_SORT_ALGO_INL_H_

#include <stdint.h>

#include <algorithm>   // std::sort, std::min, std::max
#include <functional>  // std::less, std::greater
#include <vector>

#include "hwy/base.h"
#include "hwy/contrib/sort/vqsort.h"  // VQSort, VQPartialSort, VQSelect
#include "hwy/print.h"

// Third-party algorithms
#define HAVE_AVX2SORT 0
#define HAVE_IPS4O 0
// When enabling, consider changing max_threads (required for Table 1a)
#define HAVE_PARALLEL_IPS4O (HAVE_IPS4O && 1)
#define HAVE_PDQSORT 0
#define HAVE_SORT512 0
#define HAVE_VXSORT 0
#if HWY_ARCH_X86
#define HAVE_INTEL 0
#else
#define HAVE_INTEL 0
#endif

#if HAVE_PARALLEL_IPS4O
#include <thread>  // NOLINT
#endif

#if HAVE_AVX2SORT
HWY_PUSH_ATTRIBUTES("avx2,avx")
#include "avx2sort.h"  // NOLINT
HWY_POP_ATTRIBUTES
#endif
#if HAVE_IPS4O || HAVE_PARALLEL_IPS4O
#include "third_party/ips4o/include/ips4o.hpp"
#include "third_party/ips4o/include/ips4o/thread_pool.hpp"
#endif
#if HAVE_PDQSORT
#include "third_party/boost/allowed/sort/sort.hpp"
#endif
#if HAVE_SORT512
#include "sort512.h"  // NOLINT
#endif

// vxsort is difficult to compile for multiple targets because it also uses
// .cpp files, and we'd also have to #undef its include guards. Instead,
// compile only for AVX2 or AVX3 depending on this macro.
#define VXSORT_AVX3 1
#if HAVE_VXSORT
// inlined from vxsort_targets_enable_avx512 (must close before end of header)
#ifdef __GNUC__
#ifdef __clang__
#if VXSORT_AVX3
#pragma clang attribute push(__attribute__((target("avx512f,avx512dq"))), \
                             apply_to = any(function))
#else
#pragma clang attribute push(__attribute__((target("avx2"))), \
                             apply_to = any(function))
#endif  // VXSORT_AVX3

#else
#pragma GCC push_options
#if VXSORT_AVX3
#pragma GCC target("avx512f,avx512dq")
#else
#pragma GCC target("avx2")
#endif  // VXSORT_AVX3
#endif
#endif

#if VXSORT_AVX3
#include "vxsort/machine_traits.avx512.h"
#else
#include "vxsort/machine_traits.avx2.h"
#endif  // VXSORT_AVX3
#include "vxsort/vxsort.h"
#ifdef __GNUC__
#ifdef __clang__
#pragma clang attribute pop
#else
#pragma GCC pop_options
#endif
#endif
#endif  // HAVE_VXSORT

104
105namespace hwy {
106
108
static inline std::vector<Dist> AllDist() {
  return {/*Dist::kUniform8, Dist::kUniform16,*/ Dist::kUniform32};
}

static inline const char* DistName(Dist dist) {
  switch (dist) {
    case Dist::kUniform8:
      return "uniform8";
    case Dist::kUniform16:
      return "uniform16";
    case Dist::kUniform32:
      return "uniform32";
  }
  return "unreachable";
}

template <typename T>
class InputStats {
 public:
  void Notify(T value) {
    min_ = std::min(min_, value);
    max_ = std::max(max_, value);
    // Converting to integer would truncate floats, and multiplying to save
    // digits risks overflow, especially when casting. Instead, take the sum
    // of the bit representations as the checksum.
    uint64_t bits = 0;
    static_assert(sizeof(T) <= 8, "Expected a built-in type");
    CopyBytes<sizeof(T)>(&value, &bits);  // not same size
    sum_ += bits;
    count_ += 1;
  }

  bool operator==(const InputStats& other) const {
    char type_name[100];
    detail::TypeName(hwy::detail::MakeTypeInfo<T>(), 1, type_name);

    if (count_ != other.count_) {
      HWY_ABORT("Sort %s: count %d vs %d\n", type_name,
                static_cast<int>(count_), static_cast<int>(other.count_));
    }

    if (min_ != other.min_ || max_ != other.max_) {
      HWY_ABORT("Sort %s: minmax %f/%f vs %f/%f\n", type_name,
                static_cast<double>(min_), static_cast<double>(max_),
                static_cast<double>(other.min_),
                static_cast<double>(other.max_));
    }

    // Sum helps detect duplicated/lost values
    if (sum_ != other.sum_) {
      HWY_ABORT("Sort %s: Sum mismatch %g %g; min %g max %g\n", type_name,
                static_cast<double>(sum_), static_cast<double>(other.sum_),
                static_cast<double>(min_), static_cast<double>(max_));
    }

    return true;
  }

 private:
  T min_ = hwy::HighestValue<T>();
  T max_ = hwy::LowestValue<T>();
  uint64_t sum_ = 0;
  size_t count_ = 0;
};
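
// Usage sketch (illustrative): verify that a sort neither lost nor duplicated
// values by comparing stats of the input and output; `keys`/`num` are
// hypothetical names.
//   InputStats<float> before, after;
//   for (size_t i = 0; i < num; ++i) before.Notify(keys[i]);
//   std::sort(keys, keys + num);
//   for (size_t i = 0; i < num; ++i) after.Notify(keys[i]);
//   HWY_ASSERT(before == after);  // operator== aborts with details on mismatch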

enum class Algo {
#if HAVE_INTEL
  kIntel,
#endif
#if HAVE_AVX2SORT
  kSEA,
#endif
#if HAVE_IPS4O
  kIPS4O,
#endif
#if HAVE_PARALLEL_IPS4O
  kParallelIPS4O,
#endif
#if HAVE_PDQSORT
  kPDQ,
#endif
#if HAVE_SORT512
  kSort512,
#endif
#if HAVE_VXSORT
  kVXSort,
#endif
  kStdSort,
  kStdPartialSort,
  kStdSelect,
  kVQSort,
  kVQPartialSort,
  kVQSelect,
  kHeapSort,
  kHeapPartialSort,
  kHeapSelect,
};

static inline const char* AlgoName(Algo algo) {
  switch (algo) {
#if HAVE_INTEL
    case Algo::kIntel:
      return "intel";
#endif
#if HAVE_AVX2SORT
    case Algo::kSEA:
      return "sea";
#endif
#if HAVE_IPS4O
    case Algo::kIPS4O:
      return "ips4o";
#endif
#if HAVE_PARALLEL_IPS4O
    case Algo::kParallelIPS4O:
      return "par_ips4o";
#endif
#if HAVE_PDQSORT
    case Algo::kPDQ:
      return "pdq";
#endif
#if HAVE_SORT512
    case Algo::kSort512:
      return "sort512";
#endif
#if HAVE_VXSORT
    case Algo::kVXSort:
      return "vxsort";
#endif
    case Algo::kStdSort:
    case Algo::kStdPartialSort:
    case Algo::kStdSelect:
      return "std";
    case Algo::kVQSort:
    case Algo::kVQPartialSort:
    case Algo::kVQSelect:
      return "vq";
    case Algo::kHeapSort:
    case Algo::kHeapPartialSort:
      return "heapsort";
    case Algo::kHeapSelect:
      return "heapselect";
  }
  return "unreachable";
}

}  // namespace hwy
#endif  // HIGHWAY_HWY_CONTRIB_SORT_ALGO_INL_H_

// Per-target
#if defined(HIGHWAY_HWY_CONTRIB_SORT_ALGO_TOGGLE) == defined(HWY_TARGET_TOGGLE)
#ifdef HIGHWAY_HWY_CONTRIB_SORT_ALGO_TOGGLE
#undef HIGHWAY_HWY_CONTRIB_SORT_ALGO_TOGGLE
#else
#define HIGHWAY_HWY_CONTRIB_SORT_ALGO_TOGGLE
#endif

#include "hwy/contrib/sort/traits-inl.h"
#include "hwy/contrib/sort/traits128-inl.h"
#include "hwy/contrib/sort/vqsort-inl.h"  // HeapSort

HWY_BEFORE_NAMESPACE();

// Requires target pragma set by HWY_BEFORE_NAMESPACE
#if HAVE_INTEL && HWY_TARGET <= HWY_AVX3
// #include "avx512-16bit-qsort.hpp"  // requires vbmi2
#include "avx512-32bit-qsort.hpp"
#include "avx512-64bit-qsort.hpp"
#endif

namespace hwy {
namespace HWY_NAMESPACE {

#if HAVE_INTEL || HAVE_VXSORT  // only supports ascending order
template <typename T>
using OtherOrder = detail::OrderAscending<T>;
#else
template <typename T>
using OtherOrder = detail::OrderDescending<T>;
#endif

class Xorshift128Plus {
  static HWY_INLINE uint64_t SplitMix64(uint64_t z) {
    z = (z ^ (z >> 30)) * 0xBF58476D1CE4E5B9ull;
    z = (z ^ (z >> 27)) * 0x94D049BB133111EBull;
    return z ^ (z >> 31);
  }

 public:
  // Generates two vectors of 64-bit seeds via SplitMix64 and stores them into
  // `seeds`. Generating these afresh in each ChoosePivot is too expensive.
  template <class DU64>
  static void GenerateSeeds(DU64 du64, TFromD<DU64>* HWY_RESTRICT seeds) {
    seeds[0] = SplitMix64(0x9E3779B97F4A7C15ull);
    for (size_t i = 1; i < 2 * Lanes(du64); ++i) {
      seeds[i] = SplitMix64(seeds[i - 1]);
    }
  }

  // Need to pass in the state because vectors cannot be class members.
  template <class VU64>
  static VU64 RandomBits(VU64& state0, VU64& state1) {
    VU64 s1 = state0;
    VU64 s0 = state1;
    const VU64 bits = Add(s1, s0);
    state0 = s0;
    s1 = Xor(s1, ShiftLeft<23>(s1));
    state1 = Xor(s1, Xor(s0, Xor(ShiftRight<18>(s1), ShiftRight<5>(s0))));
    return bits;
  }
};
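
// Usage sketch (illustrative): seed once, then draw a vector of random bits
// per call; this mirrors what GenerateInput below does.
//   const SortTag<uint64_t> du64;
//   auto seeds = hwy::AllocateAligned<uint64_t>(2 * Lanes(du64));
//   Xorshift128Plus::GenerateSeeds(du64, seeds.get());
//   auto s0 = Load(du64, seeds.get());
//   auto s1 = Load(du64, seeds.get() + Lanes(du64));
//   const auto bits = Xorshift128Plus::RandomBits(s0, s1);  // updates s0/s1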

template <class D, class VU64, HWY_IF_NOT_FLOAT_D(D)>
Vec<D> RandomValues(D d, VU64& s0, VU64& s1, const VU64 mask) {
  const VU64 bits = Xorshift128Plus::RandomBits(s0, s1);
  return BitCast(d, And(bits, mask));
}

// It is important to avoid denormals, which are flushed to zero by SIMD but
// not scalar sorts, and NaN, which may be ordered differently in scalar vs.
// SIMD.
template <class DF, class VU64, HWY_IF_FLOAT_D(DF)>
Vec<DF> RandomValues(DF df, VU64& s0, VU64& s1, const VU64 mask) {
  using TF = TFromD<DF>;
  const RebindToUnsigned<decltype(df)> du;
  using VU = Vec<decltype(du)>;

  const VU64 bits64 = And(Xorshift128Plus::RandomBits(s0, s1), mask);

#if HWY_TARGET == HWY_SCALAR  // Cannot repartition u64 to smaller types
  using TU = MakeUnsigned<TF>;
  const VU bits = Set(du, static_cast<TU>(GetLane(bits64) & LimitsMax<TU>()));
#else
  const VU bits = BitCast(du, bits64);
#endif
  // Avoid NaN/denormal by only generating values in [1, 2), i.e. random
  // mantissas with the exponent taken from the representation of 1.0.
  const VU k1 = BitCast(du, Set(df, TF{1.0}));
  const VU mantissa_mask = Set(du, MantissaMask<TF>());
  const VU representation = OrAnd(k1, bits, mantissa_mask);
  return BitCast(df, representation);
}
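
// For f32, for example, 1.0f is represented as 0x3F800000; OR-ing random
// mantissa bits (MantissaMask<float>() == 0x7FFFFF) into that representation
// yields values uniformly distributed in [1.0f, 2.0f).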

template <class DU64>
Vec<DU64> MaskForDist(DU64 du64, const Dist dist, size_t sizeof_t) {
  switch (sizeof_t) {
    case 2:
      return Set(du64, (dist == Dist::kUniform8) ? 0x00FF00FF00FF00FFull
                                                 : 0xFFFFFFFFFFFFFFFFull);
    case 4:
      return Set(du64, (dist == Dist::kUniform8)    ? 0x000000FF000000FFull
                       : (dist == Dist::kUniform16) ? 0x0000FFFF0000FFFFull
                                                    : 0xFFFFFFFFFFFFFFFFull);
    case 8:
      return Set(du64, (dist == Dist::kUniform8)    ? 0x00000000000000FFull
                       : (dist == Dist::kUniform16) ? 0x000000000000FFFFull
                                                    : 0x00000000FFFFFFFFull);
    default:
      HWY_ABORT("Logic error");
      return Zero(du64);
  }
}
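
// The mask limits how many low bits survive in each key lane; e.g. for 4-byte
// keys with kUniform16, each 32-bit half of a 64-bit block keeps only its low
// 16 bits, so every generated key is uniform in [0, 65536).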

template <typename T>
InputStats<T> GenerateInput(const Dist dist, T* v, size_t num) {
  SortTag<uint64_t> du64;
  using VU64 = Vec<decltype(du64)>;
  const size_t N64 = Lanes(du64);
  auto seeds = hwy::AllocateAligned<uint64_t>(2 * N64);
  Xorshift128Plus::GenerateSeeds(du64, seeds.get());
  VU64 s0 = Load(du64, seeds.get());
  VU64 s1 = Load(du64, seeds.get() + N64);

#if HWY_TARGET == HWY_SCALAR
  const Sisd<T> d;
#else
  const Repartition<T, decltype(du64)> d;
#endif
  using V = Vec<decltype(d)>;
  const size_t N = Lanes(d);
  const VU64 mask = MaskForDist(du64, dist, sizeof(T));
  auto buf = hwy::AllocateAligned<T>(N);

  size_t i = 0;
  for (; i + N <= num; i += N) {
    const V values = RandomValues(d, s0, s1, mask);
    StoreU(values, d, v + i);
  }
  if (i < num) {
    // Last partial vector: store to a buffer, then copy the valid prefix.
    const V values = RandomValues(d, s0, s1, mask);
    StoreU(values, d, buf.get());
    CopyBytes(buf.get(), v + i, (num - i) * sizeof(T));
  }

  InputStats<T> input_stats;
  for (size_t i = 0; i < num; ++i) {
    input_stats.Notify(v[i]);
  }
  return input_stats;
}
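
// Usage sketch (illustrative): fill `keys` with random u32 and capture stats
// for later verification.
//   auto keys = hwy::AllocateAligned<uint32_t>(num);
//   const InputStats<uint32_t> stats =
//       GenerateInput(Dist::kUniform32, keys.get(), num);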

struct SharedState {
#if HAVE_PARALLEL_IPS4O
  const unsigned max_threads = hwy::LimitsMax<unsigned>();  // 16 for Table 1a
  ips4o::StdThreadPool pool{static_cast<int>(
      HWY_MIN(max_threads, std::thread::hardware_concurrency() / 2))};
#endif
};

// Bridge from keys (passed to Run) to lanes as expected by HeapPartialSort.
// For non-128-bit keys they are the same:
template <class Order, typename KeyType, HWY_IF_NOT_T_SIZE(KeyType, 16)>
void CallHeapPartialSort(KeyType* HWY_RESTRICT keys, const size_t num_keys,
                         const size_t k) {
  using detail::SharedTraits;
  using detail::TraitsLane;
  if (Order().IsAscending()) {
    const SharedTraits<TraitsLane<detail::OrderAscending<KeyType>>> st;
    return detail::HeapPartialSort(st, keys, num_keys, k);
  } else {
    const SharedTraits<TraitsLane<detail::OrderDescending<KeyType>>> st;
    return detail::HeapPartialSort(st, keys, num_keys, k);
  }
}
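
// For 128-bit keys (uint128_t, K64V64), the overloads below reinterpret the
// array as uint64_t lanes, so each key occupies two lanes and num_lanes is
// twice num_keys; K32V32 packs key and value into a single 64-bit lane.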

#if VQSORT_ENABLED
template <class Order>
void CallHeapPartialSort(hwy::uint128_t* HWY_RESTRICT keys,
                         const size_t num_keys, const size_t k) {
  using detail::SharedTraits;
  using detail::Traits128;
  uint64_t* lanes = reinterpret_cast<uint64_t*>(keys);
  const size_t num_lanes = num_keys * 2;
  if (Order().IsAscending()) {
    const SharedTraits<Traits128<detail::OrderAscending128>> st;
    return detail::HeapPartialSort(st, lanes, num_lanes, k);
  } else {
    const SharedTraits<Traits128<detail::OrderDescending128>> st;
    return detail::HeapPartialSort(st, lanes, num_lanes, k);
  }
}

template <class Order>
void CallHeapPartialSort(K64V64* HWY_RESTRICT keys, const size_t num_keys,
                         const size_t k) {
  using detail::SharedTraits;
  using detail::Traits128;
  uint64_t* lanes = reinterpret_cast<uint64_t*>(keys);
  const size_t num_lanes = num_keys * 2;
  if (Order().IsAscending()) {
    const SharedTraits<Traits128<detail::OrderAscendingKV128>> st;
    return detail::HeapPartialSort(st, lanes, num_lanes, k);
  } else {
    const SharedTraits<Traits128<detail::OrderDescendingKV128>> st;
    return detail::HeapPartialSort(st, lanes, num_lanes, k);
  }
}

template <class Order>
void CallHeapPartialSort(K32V32* HWY_RESTRICT keys, const size_t num_keys,
                         const size_t k) {
  using detail::SharedTraits;
  using detail::TraitsLane;
  uint64_t* lanes = reinterpret_cast<uint64_t*>(keys);
  const size_t num_lanes = num_keys;
  if (Order().IsAscending()) {
    const SharedTraits<TraitsLane<detail::OrderAscendingKV64>> st;
    return detail::HeapPartialSort(st, lanes, num_lanes, k);
  } else {
    const SharedTraits<TraitsLane<detail::OrderDescendingKV64>> st;
    return detail::HeapPartialSort(st, lanes, num_lanes, k);
  }
}

#endif  // VQSORT_ENABLED

// Bridge from keys (passed to Run) to lanes as expected by HeapSelect. For
// non-128-bit keys they are the same:
template <class Order, typename KeyType, HWY_IF_NOT_T_SIZE(KeyType, 16)>
void CallHeapSelect(KeyType* HWY_RESTRICT keys, const size_t num_keys,
                    const size_t k) {
  using detail::SharedTraits;
  using detail::TraitsLane;
  if (Order().IsAscending()) {
    const SharedTraits<TraitsLane<detail::OrderAscending<KeyType>>> st;
    return detail::HeapSelect(st, keys, num_keys, k);
  } else {
    const SharedTraits<TraitsLane<detail::OrderDescending<KeyType>>> st;
    return detail::HeapSelect(st, keys, num_keys, k);
  }
}

#if VQSORT_ENABLED
template <class Order>
void CallHeapSelect(hwy::uint128_t* HWY_RESTRICT keys, const size_t num_keys,
                    const size_t k) {
  using detail::SharedTraits;
  using detail::Traits128;
  uint64_t* lanes = reinterpret_cast<uint64_t*>(keys);
  const size_t num_lanes = num_keys * 2;
  if (Order().IsAscending()) {
    const SharedTraits<Traits128<detail::OrderAscending128>> st;
    return detail::HeapSelect(st, lanes, num_lanes, k);
  } else {
    const SharedTraits<Traits128<detail::OrderDescending128>> st;
    return detail::HeapSelect(st, lanes, num_lanes, k);
  }
}

template <class Order>
void CallHeapSelect(K64V64* HWY_RESTRICT keys, const size_t num_keys,
                    const size_t k) {
  using detail::SharedTraits;
  using detail::Traits128;
  uint64_t* lanes = reinterpret_cast<uint64_t*>(keys);
  const size_t num_lanes = num_keys * 2;
  if (Order().IsAscending()) {
    const SharedTraits<Traits128<detail::OrderAscendingKV128>> st;
    return detail::HeapSelect(st, lanes, num_lanes, k);
  } else {
    const SharedTraits<Traits128<detail::OrderDescendingKV128>> st;
    return detail::HeapSelect(st, lanes, num_lanes, k);
  }
}

template <class Order>
void CallHeapSelect(K32V32* HWY_RESTRICT keys, const size_t num_keys,
                    const size_t k) {
  using detail::SharedTraits;
  using detail::TraitsLane;
  uint64_t* lanes = reinterpret_cast<uint64_t*>(keys);
  const size_t num_lanes = num_keys;
  if (Order().IsAscending()) {
    const SharedTraits<TraitsLane<detail::OrderAscendingKV64>> st;
    return detail::HeapSelect(st, lanes, num_lanes, k);
  } else {
    const SharedTraits<TraitsLane<detail::OrderDescendingKV64>> st;
    return detail::HeapSelect(st, lanes, num_lanes, k);
  }
}

#endif  // VQSORT_ENABLED

// Bridge from keys (passed to Run) to lanes as expected by HeapSort. For
// non-128-bit keys they are the same:
template <class Order, typename KeyType, HWY_IF_NOT_T_SIZE(KeyType, 16)>
void CallHeapSort(KeyType* HWY_RESTRICT keys, const size_t num_keys) {
  using detail::SharedTraits;
  using detail::TraitsLane;
  if (Order().IsAscending()) {
    const SharedTraits<TraitsLane<detail::OrderAscending<KeyType>>> st;
    return detail::HeapSort(st, keys, num_keys);
  } else {
    const SharedTraits<TraitsLane<detail::OrderDescending<KeyType>>> st;
    return detail::HeapSort(st, keys, num_keys);
  }
}

#if VQSORT_ENABLED
template <class Order>
void CallHeapSort(hwy::uint128_t* HWY_RESTRICT keys, const size_t num_keys) {
  using detail::SharedTraits;
  using detail::Traits128;
  uint64_t* lanes = reinterpret_cast<uint64_t*>(keys);
  const size_t num_lanes = num_keys * 2;
  if (Order().IsAscending()) {
    const SharedTraits<Traits128<detail::OrderAscending128>> st;
    return detail::HeapSort(st, lanes, num_lanes);
  } else {
    const SharedTraits<Traits128<detail::OrderDescending128>> st;
    return detail::HeapSort(st, lanes, num_lanes);
  }
}

template <class Order>
void CallHeapSort(K64V64* HWY_RESTRICT keys, const size_t num_keys) {
  using detail::SharedTraits;
  using detail::Traits128;
  uint64_t* lanes = reinterpret_cast<uint64_t*>(keys);
  const size_t num_lanes = num_keys * 2;
  if (Order().IsAscending()) {
    const SharedTraits<Traits128<detail::OrderAscendingKV128>> st;
    return detail::HeapSort(st, lanes, num_lanes);
  } else {
    const SharedTraits<Traits128<detail::OrderDescendingKV128>> st;
    return detail::HeapSort(st, lanes, num_lanes);
  }
}

template <class Order>
void CallHeapSort(K32V32* HWY_RESTRICT keys, const size_t num_keys) {
  using detail::SharedTraits;
  using detail::TraitsLane;
  uint64_t* lanes = reinterpret_cast<uint64_t*>(keys);
  const size_t num_lanes = num_keys;
  if (Order().IsAscending()) {
    const SharedTraits<TraitsLane<detail::OrderAscendingKV64>> st;
    return detail::HeapSort(st, lanes, num_lanes);
  } else {
    const SharedTraits<TraitsLane<detail::OrderDescendingKV64>> st;
    return detail::HeapSort(st, lanes, num_lanes);
  }
}

#endif  // VQSORT_ENABLED

template <class Order, typename KeyType>
void Run(Algo algo, KeyType* HWY_RESTRICT inout, size_t num,
         SharedState& shared, size_t /*thread*/, size_t k = 0) {
  const std::less<KeyType> less;
  const std::greater<KeyType> greater;

#if !HAVE_PARALLEL_IPS4O
  (void)shared;
#endif

  switch (algo) {
#if HAVE_INTEL && HWY_TARGET <= HWY_AVX3
    case Algo::kIntel:
      return avx512_qsort<KeyType>(inout, static_cast<int64_t>(num));
#endif

#if HAVE_AVX2SORT
    case Algo::kSEA:
      return avx2::quicksort(inout, static_cast<int>(num));
#endif

#if HAVE_IPS4O
    case Algo::kIPS4O:
      if (Order().IsAscending()) {
        return ips4o::sort(inout, inout + num, less);
      } else {
        return ips4o::sort(inout, inout + num, greater);
      }
#endif

#if HAVE_PARALLEL_IPS4O
    case Algo::kParallelIPS4O:
      if (Order().IsAscending()) {
        return ips4o::parallel::sort(inout, inout + num, less, shared.pool);
      } else {
        return ips4o::parallel::sort(inout, inout + num, greater, shared.pool);
      }
#endif

#if HAVE_SORT512
    case Algo::kSort512:
      HWY_ABORT("not supported");
      // return Sort512::Sort(inout, num);
#endif

#if HAVE_PDQSORT
    case Algo::kPDQ:
      if (Order().IsAscending()) {
        return boost::sort::pdqsort_branchless(inout, inout + num, less);
      } else {
        return boost::sort::pdqsort_branchless(inout, inout + num, greater);
      }
#endif

#if HAVE_VXSORT
    case Algo::kVXSort: {
#if (VXSORT_AVX3 && HWY_TARGET != HWY_AVX3) || \
    (!VXSORT_AVX3 && HWY_TARGET != HWY_AVX2)
      fprintf(stderr, "Do not call for target %s\n",
              hwy::TargetName(HWY_TARGET));
      return;
#else
#if VXSORT_AVX3
      vxsort::vxsort<KeyType, vxsort::AVX512> vx;
#else
      vxsort::vxsort<KeyType, vxsort::AVX2> vx;
#endif
      if (Order().IsAscending()) {
        // vxsort takes an inclusive range, hence the - 1.
        return vx.sort(inout, inout + num - 1);
      } else {
        fprintf(stderr, "Skipping VX - does not support descending order\n");
        return;
      }
#endif  // enabled for this target
    }
#endif  // HAVE_VXSORT

    case Algo::kStdSort:
      if (Order().IsAscending()) {
        return std::sort(inout, inout + num, less);
      } else {
        return std::sort(inout, inout + num, greater);
      }

    case Algo::kStdPartialSort:
      if (Order().IsAscending()) {
        return std::partial_sort(inout, inout + k, inout + num, less);
      } else {
        return std::partial_sort(inout, inout + k, inout + num, greater);
      }

    case Algo::kStdSelect:
      if (Order().IsAscending()) {
        return std::nth_element(inout, inout + k, inout + num, less);
      } else {
        return std::nth_element(inout, inout + k, inout + num, greater);
      }

    case Algo::kVQSort:
      return VQSort(inout, num, Order());

    case Algo::kVQPartialSort:
      return VQPartialSort(inout, num, k, Order());

    case Algo::kVQSelect:
      return VQSelect(inout, num, k, Order());

    case Algo::kHeapSort:
      return CallHeapSort<Order>(inout, num);

    case Algo::kHeapPartialSort:
      return CallHeapPartialSort<Order>(inout, num, k);

    case Algo::kHeapSelect:
      return CallHeapSelect<Order>(inout, num, k);

    default:
      HWY_ABORT("Not implemented");
  }
}
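
// Usage sketch (illustrative): dispatch an ascending vectorized sort.
// `SortAscending` is the order tag declared alongside VQSort.
//   SharedState shared;
//   Run<SortAscending>(Algo::kVQSort, keys.get(), num, shared, /*thread=*/0);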

// NOLINTNEXTLINE(google-readability-namespace-comments)
}  // namespace HWY_NAMESPACE
}  // namespace hwy
HWY_AFTER_NAMESPACE();

#endif  // HIGHWAY_HWY_CONTRIB_SORT_ALGO_TOGGLE