Grok 12.0.1
random-inl.h
Go to the documentation of this file.
1/*
2 * Original implementation written in 2019
3 * by David Blackman and Sebastiano Vigna (vigna@acm.org)
4 * Available at https://prng.di.unimi.it/ with creative commons license:
5 * To the extent possible under law, the author has dedicated all copyright
6 * and related and neighboring rights to this software to the public domain
7 * worldwide. This software is distributed without any warranty.
8 * See <http://creativecommons.org/publicdomain/zero/1.0/>.
9 *
10 * This implementation is a Vector port of the original implementation
11 * written by Marco Barbone (m.barbone19@imperial.ac.uk).
12 * I take no credit for the original implementation.
13 * The code is provided as is and the original license applies.
14 */
15
16#if defined(HIGHWAY_HWY_CONTRIB_RANDOM_RANDOM_H_) == \
17 defined(HWY_TARGET_TOGGLE) // NOLINT
18#ifdef HIGHWAY_HWY_CONTRIB_RANDOM_RANDOM_H_
19#undef HIGHWAY_HWY_CONTRIB_RANDOM_RANDOM_H_
20#else
21#define HIGHWAY_HWY_CONTRIB_RANDOM_RANDOM_H_
22#endif
23
24#include <array>
25#include <cstdint>
26#include <limits>
27
29#include "hwy/highway.h"
30
31HWY_BEFORE_NAMESPACE(); // required if not using HWY_ATTR
32
33namespace hwy {
34
35namespace HWY_NAMESPACE { // required: unique per target
36namespace internal {
37
38namespace {
39#if HWY_HAVE_FLOAT64
// Scale factor 2^-53: multiplying a 53-bit integer by it yields a double in
// [0, 1). Hexadecimal floating literals are only standard from C++17 on, so
// older dialects fall back to the exact decimal expansion of the same value.
// The feature-test macro equals 201603L when hexfloat is supported, hence the
// comparison must be >= (a strict > would never select the hexfloat branch).
#if defined(__cpp_hex_float) && __cpp_hex_float >= 201603L
constexpr double kMulConst = 0x1.0p-53;
#else
constexpr double kMulConst =
    0.00000000000000011102230246251565404236316680908203125;
#endif  // __cpp_hex_float
47
48#endif // HWY_HAVE_FLOAT64
49
// Bit masks walked by Xoshiro::Jump(): every set bit selects one intermediate
// generator state that is XOR-ed into the post-jump state.
constexpr std::uint64_t kJump[4] = {
    UINT64_C(0x180ec6d33cfd0aba),
    UINT64_C(0xd5a61266f0c9392c),
    UINT64_C(0xa9582618e03fc9aa),
    UINT64_C(0x39abdc4529b1661c),
};

// Same layout as kJump, used by Xoshiro::LongJump().
constexpr std::uint64_t kLongJump[4] = {
    UINT64_C(0x76e15d3efefdcbbf),
    UINT64_C(0xc5004e441c522fb3),
    UINT64_C(0x77710069854ee241),
    UINT64_C(0x39109bb02acbe635),
};
55} // namespace
56
// SplitMix64 generator. Xoshiro uses it to expand a single 64-bit seed into
// its four state words.
class SplitMix64 {
 public:
  // Starts the sequence from the given state value.
  constexpr explicit SplitMix64(const std::uint64_t state) noexcept
      : state_(state) {}

  // Advances the state by a fixed odd increment and returns the result of two
  // xor-shift-multiply rounds followed by a final xor-shift.
  HWY_CXX14_CONSTEXPR std::uint64_t operator()() {
    std::uint64_t z = (state_ += 0x9e3779b97f4a7c15);
    z = (z ^ (z >> 30)) * 0xbf58476d1ce4e5b9;
    z = (z ^ (z >> 27)) * 0x94d049bb133111eb;
    return z ^ (z >> 31);
  }

 private:
  std::uint64_t state_;
};
72
73class Xoshiro {
74 public:
75 HWY_CXX14_CONSTEXPR explicit Xoshiro(const std::uint64_t seed) noexcept
76 : state_{} {
77 SplitMix64 splitMix64{seed};
78 for (auto &element : state_) {
79 element = splitMix64();
80 }
81 }
82
83 HWY_CXX14_CONSTEXPR explicit Xoshiro(const std::uint64_t seed,
84 const std::uint64_t thread_id) noexcept
85 : Xoshiro(seed) {
86 for (auto i = UINT64_C(0); i < thread_id; ++i) {
87 Jump();
88 }
89 }
90
91 HWY_CXX14_CONSTEXPR std::uint64_t operator()() noexcept { return Next(); }
92
93#if HWY_HAVE_FLOAT64
94 HWY_CXX14_CONSTEXPR double Uniform() noexcept {
95 return static_cast<double>(Next() >> 11) * kMulConst;
96 }
97#endif
98
99 HWY_CXX14_CONSTEXPR std::array<std::uint64_t, 4> GetState() const {
100 return {state_[0], state_[1], state_[2], state_[3]};
101 }
102
104 std::array<std::uint64_t, 4> state) noexcept {
105 state_[0] = state[0];
106 state_[1] = state[1];
107 state_[2] = state[2];
108 state_[3] = state[3];
109 }
110
111 static constexpr std::uint64_t StateSize() noexcept { return 4; }
112
113 /* This is the jump function for the generator. It is equivalent to 2^128
114 * calls to next(); it can be used to generate 2^128 non-overlapping
115 * subsequences for parallel computations. */
116 HWY_CXX14_CONSTEXPR void Jump() noexcept { Jump(kJump); }
117
118 /* This is the long-jump function for the generator. It is equivalent to 2^192
119 * calls to next(); it can be used to generate 2^64 starting points, from each
120 * of which jump() will generate 2^64 non-overlapping subsequences for
121 * parallel distributed computations. */
122 HWY_CXX14_CONSTEXPR void LongJump() noexcept { Jump(kLongJump); }
123
124 private:
125 std::uint64_t state_[4];
126
127 static constexpr std::uint64_t Rotl(const std::uint64_t x, int k) noexcept {
128 return (x << k) | (x >> (64 - k));
129 }
130
131 HWY_CXX14_CONSTEXPR std::uint64_t Next() noexcept {
132 const std::uint64_t result = Rotl(state_[0] + state_[3], 23) + state_[0];
133 const std::uint64_t t = state_[1] << 17;
134
135 state_[2] ^= state_[0];
136 state_[3] ^= state_[1];
137 state_[1] ^= state_[2];
138 state_[0] ^= state_[3];
139
140 state_[2] ^= t;
141
142 state_[3] = Rotl(state_[3], 45);
143
144 return result;
145 }
146
147 HWY_CXX14_CONSTEXPR void Jump(const std::uint64_t (&jumpArray)[4]) noexcept {
148 std::uint64_t s0 = 0;
149 std::uint64_t s1 = 0;
150 std::uint64_t s2 = 0;
151 std::uint64_t s3 = 0;
152
153 for (const std::uint64_t i : jumpArray)
154 for (std::uint_fast8_t b = 0; b < 64; b++) {
155 if (i & std::uint64_t{1UL} << b) {
156 s0 ^= state_[0];
157 s1 ^= state_[1];
158 s2 ^= state_[2];
159 s3 ^= state_[3];
160 }
161 Next();
162 }
163
164 state_[0] = s0;
165 state_[1] = s1;
166 state_[2] = s2;
167 state_[3] = s3;
168 }
169};
170
171} // namespace internal
172
174 private:
177#if HWY_HAVE_FLOAT64
178 using VF64 = Vec<ScalableTag<double>>;
179#endif
180 public:
181 explicit VectorXoshiro(const std::uint64_t seed,
182 const std::uint64_t threadNumber = 0)
185 streams{state_.shape().back()} {
186 internal::Xoshiro xoshiro{seed};
187
188 for (std::uint64_t i = 0; i < threadNumber; ++i) {
189 xoshiro.LongJump();
190 }
191
192 for (size_t i = 0UL; i < streams; ++i) {
193 const auto state = xoshiro.GetState();
194 for (size_t j = 0UL; j < internal::Xoshiro::StateSize(); ++j) {
195 state_[{j}][i] = state[j];
196 }
197 xoshiro.Jump();
198 }
199 }
200
201 HWY_INLINE VU64 operator()() noexcept { return Next(); }
202
205 const ScalableTag<std::uint64_t> tag{};
206 auto s0 = Load(tag, state_[{0}].data());
207 auto s1 = Load(tag, state_[{1}].data());
208 auto s2 = Load(tag, state_[{2}].data());
209 auto s3 = Load(tag, state_[{3}].data());
210 for (std::uint64_t i = 0; i < n; i += Lanes(tag)) {
211 const auto next = Update(s0, s1, s2, s3);
212 Store(next, tag, result.data() + i);
213 }
214 Store(s0, tag, state_[{0}].data());
215 Store(s1, tag, state_[{1}].data());
216 Store(s2, tag, state_[{2}].data());
217 Store(s3, tag, state_[{3}].data());
218 return result;
219 }
220
221 template <std::uint64_t N>
222 std::array<std::uint64_t, N> operator()() noexcept {
223 alignas(HWY_ALIGNMENT) std::array<std::uint64_t, N> result;
224 const ScalableTag<std::uint64_t> tag{};
225 auto s0 = Load(tag, state_[{0}].data());
226 auto s1 = Load(tag, state_[{1}].data());
227 auto s2 = Load(tag, state_[{2}].data());
228 auto s3 = Load(tag, state_[{3}].data());
229 for (std::uint64_t i = 0; i < N; i += Lanes(tag)) {
230 const auto next = Update(s0, s1, s2, s3);
231 Store(next, tag, result.data() + i);
232 }
233 Store(s0, tag, state_[{0}].data());
234 Store(s1, tag, state_[{1}].data());
235 Store(s2, tag, state_[{2}].data());
236 Store(s3, tag, state_[{3}].data());
237 return result;
238 }
239
240 std::uint64_t StateSize() const noexcept {
241 return streams * internal::Xoshiro::StateSize();
242 }
243
244 const StateType &GetState() const { return state_; }
245
246#if HWY_HAVE_FLOAT64
247
248 HWY_INLINE VF64 Uniform() noexcept {
249 const ScalableTag<double> real_tag{};
250 const auto MUL_VALUE = Set(real_tag, internal::kMulConst);
251 const auto bits = ShiftRight<11>(Next());
252 const auto real = ConvertTo(real_tag, bits);
253 return Mul(real, MUL_VALUE);
254 }
255
256 AlignedVector<double> Uniform(const std::size_t n) {
257 AlignedVector<double> result(n);
258 const ScalableTag<std::uint64_t> tag{};
259 const ScalableTag<double> real_tag{};
260 const auto MUL_VALUE = Set(real_tag, internal::kMulConst);
261
262 auto s0 = Load(tag, state_[{0}].data());
263 auto s1 = Load(tag, state_[{1}].data());
264 auto s2 = Load(tag, state_[{2}].data());
265 auto s3 = Load(tag, state_[{3}].data());
266
267 for (std::uint64_t i = 0; i < n; i += Lanes(real_tag)) {
268 const auto next = Update(s0, s1, s2, s3);
269 const auto bits = ShiftRight<11>(next);
270 const auto real = ConvertTo(real_tag, bits);
271 const auto uniform = Mul(real, MUL_VALUE);
272 Store(uniform, real_tag, result.data() + i);
273 }
274
275 Store(s0, tag, state_[{0}].data());
276 Store(s1, tag, state_[{1}].data());
277 Store(s2, tag, state_[{2}].data());
278 Store(s3, tag, state_[{3}].data());
279 return result;
280 }
281
282 template <std::uint64_t N>
283 std::array<double, N> Uniform() noexcept {
284 alignas(HWY_ALIGNMENT) std::array<double, N> result;
285 const ScalableTag<std::uint64_t> tag{};
286 const ScalableTag<double> real_tag{};
287 const auto MUL_VALUE = Set(real_tag, internal::kMulConst);
288
289 auto s0 = Load(tag, state_[{0}].data());
290 auto s1 = Load(tag, state_[{1}].data());
291 auto s2 = Load(tag, state_[{2}].data());
292 auto s3 = Load(tag, state_[{3}].data());
293
294 for (std::uint64_t i = 0; i < N; i += Lanes(real_tag)) {
295 const auto next = Update(s0, s1, s2, s3);
296 const auto bits = ShiftRight<11>(next);
297 const auto real = ConvertTo(real_tag, bits);
298 const auto uniform = Mul(real, MUL_VALUE);
299 Store(uniform, real_tag, result.data() + i);
300 }
301
302 Store(s0, tag, state_[{0}].data());
303 Store(s1, tag, state_[{1}].data());
304 Store(s2, tag, state_[{2}].data());
305 Store(s3, tag, state_[{3}].data());
306 return result;
307 }
308
309#endif
310
311 private:
313 const std::uint64_t streams;
314
315 HWY_INLINE static VU64 Update(VU64 &s0, VU64 &s1, VU64 &s2,
316 VU64 &s3) noexcept {
317 const auto result = Add(RotateRight<41>(Add(s0, s3)), s0);
318 const auto t = ShiftLeft<17>(s1);
319 s2 = Xor(s2, s0);
320 s3 = Xor(s3, s1);
321 s1 = Xor(s1, s2);
322 s0 = Xor(s0, s3);
323 s2 = Xor(s2, t);
324 s3 = RotateRight<19>(s3);
325 return result;
326 }
327
328 HWY_INLINE VU64 Next() noexcept {
329 const ScalableTag<std::uint64_t> tag{};
330 auto s0 = Load(tag, state_[{0}].data());
331 auto s1 = Load(tag, state_[{1}].data());
332 auto s2 = Load(tag, state_[{2}].data());
333 auto s3 = Load(tag, state_[{3}].data());
334 auto result = Update(s0, s1, s2, s3);
335 Store(s0, tag, state_[{0}].data());
336 Store(s1, tag, state_[{1}].data());
337 Store(s2, tag, state_[{2}].data());
338 Store(s3, tag, state_[{3}].data());
339 return result;
340 }
341};
342
343template <std::uint64_t size = 1024>
345 public:
346 using result_type = std::uint64_t;
347
348 static constexpr result_type(min)() {
349 return (std::numeric_limits<result_type>::min)();
350 }
351
352 static constexpr result_type(max)() {
353 return (std::numeric_limits<result_type>::max)();
354 }
355
356 explicit CachedXoshiro(const result_type seed,
357 const result_type threadNumber = 0)
358 : generator_{seed, threadNumber},
359 cache_{generator_.operator()<size>()},
360 index_{0} {}
361
363 if (HWY_UNLIKELY(index_ == size)) {
364 cache_ = std::move(generator_.operator()<size>());
365 index_ = 0;
366 }
367 return cache_[index_++];
368 }
369
370 private:
372 alignas(HWY_ALIGNMENT) std::array<result_type, size> cache_;
373 std::size_t index_;
374
375 static_assert((size & (size - 1)) == 0 && size != 0,
376 "only power of 2 are supported");
377};
378
379} // namespace HWY_NAMESPACE
380} // namespace hwy
381
383
384#endif // HIGHWAY_HWY_CONTRIB_RANDOM_RANDOM_H_
#define HWY_ALIGNMENT
Definition aligned_allocator.h:41
#define HWY_CXX14_CONSTEXPR
Definition base.h:304
#define HWY_INLINE
Definition base.h:101
#define HWY_CXX17_CONSTEXPR
Definition base.h:299
#define HWY_UNLIKELY(expr)
Definition base.h:107
const std::array< size_t, axes > & shape() const
Definition aligned_allocator.h:351
T * data()
Definition aligned_allocator.h:366
Definition random-inl.h:344
std::size_t index_
Definition random-inl.h:373
std::uint64_t result_type
Definition random-inl.h:346
result_type operator()() noexcept
Definition random-inl.h:362
CachedXoshiro(const result_type seed, const result_type threadNumber=0)
Definition random-inl.h:356
VectorXoshiro generator_
Definition random-inl.h:371
Definition random-inl.h:173
const std::uint64_t streams
Definition random-inl.h:313
const StateType & GetState() const
Definition random-inl.h:244
static HWY_INLINE VU64 Update(VU64 &s0, VU64 &s1, VU64 &s2, VU64 &s3) noexcept
Definition random-inl.h:315
AlignedVector< std::uint64_t > operator()(const std::size_t n)
Definition random-inl.h:203
StateType state_
Definition random-inl.h:312
Vec< ScalableTag< std::uint64_t > > VU64
Definition random-inl.h:175
HWY_INLINE VU64 Next() noexcept
Definition random-inl.h:328
HWY_INLINE VU64 operator()() noexcept
Definition random-inl.h:201
VectorXoshiro(const std::uint64_t seed, const std::uint64_t threadNumber=0)
Definition random-inl.h:181
std::array< std::uint64_t, N > operator()() noexcept
Definition random-inl.h:222
std::uint64_t StateSize() const noexcept
Definition random-inl.h:240
constexpr SplitMix64(const std::uint64_t state) noexcept
Definition random-inl.h:59
HWY_CXX14_CONSTEXPR std::uint64_t operator()()
Definition random-inl.h:62
std::uint64_t state_
Definition random-inl.h:70
Definition random-inl.h:73
std::uint64_t state_[4]
Definition random-inl.h:125
HWY_CXX14_CONSTEXPR std::uint64_t Next() noexcept
Definition random-inl.h:131
HWY_CXX14_CONSTEXPR std::uint64_t operator()() noexcept
Definition random-inl.h:91
HWY_CXX14_CONSTEXPR std::array< std::uint64_t, 4 > GetState() const
Definition random-inl.h:99
HWY_CXX14_CONSTEXPR Xoshiro(const std::uint64_t seed) noexcept
Definition random-inl.h:75
static constexpr std::uint64_t StateSize() noexcept
Definition random-inl.h:111
static constexpr std::uint64_t Rotl(const std::uint64_t x, int k) noexcept
Definition random-inl.h:127
HWY_CXX14_CONSTEXPR void Jump() noexcept
Definition random-inl.h:116
HWY_CXX14_CONSTEXPR void LongJump() noexcept
Definition random-inl.h:122
HWY_CXX14_CONSTEXPR void Jump(const std::uint64_t(&jumpArray)[4]) noexcept
Definition random-inl.h:147
HWY_CXX17_CONSTEXPR void SetState(std::array< std::uint64_t, 4 > state) noexcept
Definition random-inl.h:103
HWY_CXX14_CONSTEXPR Xoshiro(const std::uint64_t seed, const std::uint64_t thread_id) noexcept
Definition random-inl.h:83
HWY_API void Store(VFromD< D > v, D d, TFromD< D > *HWY_RESTRICT aligned)
Definition arm_neon-inl.h:3911
HWY_API Vec128< float > ConvertTo(D, Vec128< int32_t > v)
Definition arm_neon-inl.h:3971
HWY_API VFromD< D > Load(D d, const TFromD< D > *HWY_RESTRICT p)
Definition arm_neon-inl.h:3664
HWY_API V Add(V a, V b)
Definition generic_ops-inl.h:7300
HWY_API Vec128< T, N > Xor(const Vec128< T, N > a, const Vec128< T, N > b)
Definition arm_neon-inl.h:2739
typename detail::ScalableTagChecker< T, kPow2 >::type ScalableTag
Definition ops/shared-inl.h:367
HWY_INLINE Vec128< TFromD< D > > Set(D, T t)
Definition arm_neon-inl.h:931
decltype(Zero(D())) Vec
Definition generic_ops-inl.h:46
HWY_API size_t Lanes(D)
Definition rvv-inl.h:598
HWY_API V Mul(V a, V b)
Definition generic_ops-inl.h:7309
Definition abort.h:8
std::vector< T, AlignedAllocator< T > > AlignedVector
Definition aligned_allocator.h:172
HWY_AFTER_NAMESPACE()
HWY_BEFORE_NAMESPACE()
#define HWY_NAMESPACE
Definition set_macros-inl.h:166