matvec-inl.h
// Copyright 2023 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Include guard (still compiled once per target)
#if defined(HIGHWAY_HWY_CONTRIB_MATVEC_MATVEC_INL_H_) == \
    defined(HWY_TARGET_TOGGLE)
#ifdef HIGHWAY_HWY_CONTRIB_MATVEC_MATVEC_INL_H_
#undef HIGHWAY_HWY_CONTRIB_MATVEC_MATVEC_INL_H_
#else
#define HIGHWAY_HWY_CONTRIB_MATVEC_MATVEC_INL_H_
#endif

#include "hwy/cache_control.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/highway.h"

HWY_BEFORE_NAMESPACE();
namespace hwy {
namespace HWY_NAMESPACE {

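// Promotes both operands to float before adding, so narrow output types such
// as hwy::bfloat16_t are handled uniformly.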
template <typename TA, typename TB>
TA AddScalar(TA a, TB b) {
  return ConvertScalarTo<TA>(ConvertScalarTo<float>(a) +
                             ConvertScalarTo<float>(b));
}

template <size_t kOuter, size_t kInner, typename T, bool kAdd>
HWY_NOINLINE void MatVecAddImpl(const T* HWY_RESTRICT mat,
                                const T* HWY_RESTRICT vec,
                                const T* HWY_RESTRICT add, T* HWY_RESTRICT out,
                                hwy::ThreadPool& pool) {
  (void)add;

  // Process multiple rows at a time so that we write multiples of a cache line
  // to avoid false sharing (>= 64). 128 is better than 256. 512 has too little
  // parallelization potential.
  constexpr size_t kChunkSize = 64 / sizeof(T);
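  // For example, with T = float this is 64 / 4 = 16 rows, whose 16 float
  // outputs fill exactly one 64-byte cache line.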
  const uint64_t num_chunks = static_cast<uint64_t>(kOuter / kChunkSize);

  const ScalableTag<T> d;
  const size_t N = Lanes(d);
  // Required for Stream loop, otherwise we might have partial vectors.
  HWY_DASSERT(kChunkSize >= N);
  pool.Run(0, num_chunks,
           [&](const uint64_t chunk, size_t /*thread*/) HWY_ATTR {
             // MSVC workaround: duplicate to ensure constexpr.
             constexpr size_t kChunkSize = 64 / sizeof(T);
             // Software write-combining to avoid cache pollution from out.
             // Although `out` may be used later, keeping it out of the cache
             // now and avoiding RFOs is a consistent 5% overall win.
             HWY_ALIGN T buf[kChunkSize];

             // Only handle entire chunks here because the Stream is not
             // masked. Remaining rows are handled after the pool.Run.
             const size_t begin = static_cast<size_t>(chunk * kChunkSize);
             for (size_t idx_row = 0; idx_row < kChunkSize; ++idx_row) {
               auto sum0 = Zero(d);
               auto sum1 = Zero(d);
               // 4x unrolling barely helps SKX but likely helps Arm V2.
               auto sum2 = Zero(d);
               auto sum3 = Zero(d);

               const T* HWY_RESTRICT row = &mat[(begin + idx_row) * kInner];
               size_t i = 0;
               // No clear win from prefetching from the next 1..3 rows.
               // clflush &row[i] is slow, clflushopt less so but not helping.
               HWY_UNROLL(1)
               for (; i + 4 * N <= kInner; i += 4 * N) {
                 const auto a0 = LoadU(d, row + i + 0 * N);
                 const auto v0 = LoadU(d, vec + i + 0 * N);
                 sum0 = MulAdd(a0, v0, sum0);

                 const auto a1 = LoadU(d, row + i + 1 * N);
                 const auto v1 = LoadU(d, vec + i + 1 * N);
                 sum1 = MulAdd(a1, v1, sum1);

                 const auto a2 = LoadU(d, row + i + 2 * N);
                 const auto v2 = LoadU(d, vec + i + 2 * N);
                 sum2 = MulAdd(a2, v2, sum2);

                 const auto a3 = LoadU(d, row + i + 3 * N);
                 const auto v3 = LoadU(d, vec + i + 3 * N);
                 sum3 = MulAdd(a3, v3, sum3);
               }
               // Last entire vectors
               for (; i + N <= kInner; i += N) {
                 const auto a0 = LoadU(d, row + i);
                 const auto v0 = LoadU(d, vec + i);
                 sum0 = MulAdd(a0, v0, sum0);
               }
               const size_t remainder = kInner - i;
               if (remainder != 0) {
                 const auto a0 = LoadN(d, row + i, remainder);
                 const auto v0 = LoadN(d, vec + i, remainder);
                 sum1 = MulAdd(a0, v0, sum1);
               }
               // Reduction tree: sum of all accumulators, then their lanes
               sum2 = Add(sum2, sum3);
               sum0 = Add(sum0, sum1);
               sum0 = Add(sum0, sum2);
               buf[idx_row] = ReduceSum(d, sum0);
               HWY_IF_CONSTEXPR(kAdd) {
                 buf[idx_row] = AddScalar(buf[idx_row], add[begin + idx_row]);
               }
             }  // idx_row
             HWY_UNROLL(4)  // 1..4 iterations
             for (size_t i = 0; i != kChunkSize; i += N) {
               Stream(Load(d, buf + i), d, out + begin + i);
             }
           });
  hwy::FlushStream();

  // Handle remainder rows which are not a multiple of the chunk size.
  for (size_t r = num_chunks * kChunkSize; r < kOuter; ++r) {
    auto sum0 = Zero(d);

    const T* HWY_RESTRICT row = &mat[r * kInner];
    size_t i = 0;
    HWY_UNROLL(1)
    for (; i + N <= kInner; i += N) {
      const auto a0 = LoadU(d, row + i);
      const auto v0 = LoadU(d, vec + i);
      sum0 = MulAdd(a0, v0, sum0);
    }
    const size_t remainder = kInner - i;
    if (remainder != 0) {
      const auto a0 = LoadN(d, row + i, remainder);
      const auto v0 = LoadN(d, vec + i, remainder);
      sum0 = MulAdd(a0, v0, sum0);
    }
    out[r] = ReduceSum(d, sum0);
    HWY_IF_CONSTEXPR(kAdd) { out[r] = AddScalar(out[r], add[r]); }
  }  // r
}

// Multiplies mat with vec, adds add and puts the result in out.
//
// mat is a (kOuter, kInner)-shaped array, where element [i,j] is located at
// index i * kInner + j.
//
// vec is a (kInner,)-shaped array.
//
// add is a (kOuter,)-shaped array.
//
// out is a (kOuter,)-shaped array that will be set to mat @ vec + add.
template <size_t kOuter, size_t kInner, typename T>
HWY_NOINLINE void MatVecAdd(const T* HWY_RESTRICT mat,
                            const T* HWY_RESTRICT vec,
                            const T* HWY_RESTRICT add, T* HWY_RESTRICT out,
                            hwy::ThreadPool& pool) {
  MatVecAddImpl<kOuter, kInner, T, true>(mat, vec, add, out, pool);
}
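
// Example usage (a minimal sketch, not part of this header; the sizes and the
// ThreadPool worker-thread count are hypothetical, and the call is assumed to
// happen inside a HWY_NAMESPACE function):
//   constexpr size_t kOuter = 128;  // rows of mat, length of add and out
//   constexpr size_t kInner = 256;  // columns of mat, length of vec
//   hwy::ThreadPool pool(/*num_worker_threads=*/4);
//   std::vector<float> mat(kOuter * kInner), vec(kInner), add(kOuter);
//   std::vector<float> out(kOuter);
//   MatVecAdd<kOuter, kInner>(mat.data(), vec.data(), add.data(), out.data(),
//                             pool);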

// Multiplies mat with vec and puts the result in out.
//
// mat is a (kOuter, kInner)-shaped array, where element [i,j] is located at
// index i * kInner + j.
//
// vec is a (kInner,)-shaped array.
//
// out is a (kOuter,)-shaped array that will be set to mat @ vec.
template <size_t kOuter, size_t kInner, typename T>
HWY_NOINLINE void MatVec(const T* HWY_RESTRICT mat, const T* HWY_RESTRICT vec,
                         T* HWY_RESTRICT out, hwy::ThreadPool& pool) {
  MatVecAddImpl<kOuter, kInner, T, false>(mat, vec, /*add=*/nullptr, out, pool);
}
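
// MatVec is the same sketch without the bias vector, e.g. (hypothetical names
// as in the example above):
//   MatVec<kOuter, kInner>(mat.data(), vec.data(), out.data(), pool);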

// This target lacks too many of the ops required by our implementation;
// use HWY_EMU128 instead.
#if HWY_TARGET != HWY_SCALAR

// Specialization for bf16 matrix, which halves memory bandwidth requirements.
template <size_t kOuter, size_t kInner, bool kAdd>
HWY_NOINLINE void MatVecAddImpl(const hwy::bfloat16_t* HWY_RESTRICT mat,
                                const float* HWY_RESTRICT vec,
                                const float* HWY_RESTRICT add,
                                float* HWY_RESTRICT out,
                                hwy::ThreadPool& pool) {
  // Process multiple rows at a time so that we write multiples of a cache line
  // to avoid false sharing (>= 64). 128 is better than 256. 512 has too little
  // parallelization potential.
  constexpr size_t kChunkSize = 64 / sizeof(float);
  const uint64_t num_chunks = static_cast<uint64_t>(kOuter / kChunkSize);

  const ScalableTag<float> d;
  const Repartition<hwy::bfloat16_t, decltype(d)> d16;
  // In the remainder loop, we only process a single f32 vector, so load half
  // vectors of bf16 to avoid overrun.
  const Half<decltype(d16)> d16h;
  using V = Vec<decltype(d)>;
  using V16 = Vec<decltype(d16)>;
  using V16H = Vec<decltype(d16h)>;
  const size_t N = Lanes(d);
  // Required for Stream loop, otherwise we might have partial vectors.
  HWY_DASSERT(kChunkSize >= N);
  pool.Run(0, num_chunks,
           [&](const uint64_t chunk, size_t /*thread*/) HWY_ATTR {
             // MSVC workaround: duplicate to ensure constexpr.
             constexpr size_t kChunkSize = 64 / sizeof(float);
             // Software write-combining to avoid cache pollution from out.
             // Although `out` may be used later, keeping it out of the cache
             // now and avoiding RFOs is a consistent 5% overall win.
             HWY_ALIGN float buf[kChunkSize];

             // Only handle entire chunks here because the Stream is not
             // masked. Remaining rows are handled after the pool.Run.
             const size_t begin = static_cast<size_t>(chunk * kChunkSize);
             for (size_t idx_row = 0; idx_row < kChunkSize; ++idx_row) {
               auto sum0 = Zero(d);
               auto sum1 = Zero(d);
               // 4x unrolling barely helps SKX but likely helps Arm V2.
               auto sum2 = Zero(d);
               auto sum3 = Zero(d);

               const hwy::bfloat16_t* HWY_RESTRICT row =
                   &mat[(begin + idx_row) * kInner];
               size_t i = 0;
               // No clear win from prefetching from the next 1..3 rows.
               // clflush &row[i] is slow, clflushopt less so but not helping.
               HWY_UNROLL(1)
               for (; i + 4 * N <= kInner; i += 4 * N) {
                 const V16 b0 = LoadU(d16, row + i + 0 * N);
                 const V a0 = PromoteLowerTo(d, b0);
                 const V a1 = PromoteUpperTo(d, b0);

                 const V16 b1 = LoadU(d16, row + i + 2 * N);
                 const V a2 = PromoteLowerTo(d, b1);
                 const V a3 = PromoteUpperTo(d, b1);

                 const V v0 = LoadU(d, vec + i + 0 * N);
                 sum0 = MulAdd(a0, v0, sum0);

                 const V v1 = LoadU(d, vec + i + 1 * N);
                 sum1 = MulAdd(a1, v1, sum1);

                 const V v2 = LoadU(d, vec + i + 2 * N);
                 sum2 = MulAdd(a2, v2, sum2);

                 const V v3 = LoadU(d, vec + i + 3 * N);
                 sum3 = MulAdd(a3, v3, sum3);
               }
               // Last entire vectors
               for (; i + N <= kInner; i += N) {
                 const V16H b0 = LoadU(d16h, row + i);
                 const V a0 = PromoteTo(d, b0);
                 const V v0 = LoadU(d, vec + i);
                 sum0 = MulAdd(a0, v0, sum0);
               }
               const size_t remainder = kInner - i;
               if (remainder != 0) {
                 const V16H b0 = LoadN(d16h, row + i, remainder);
                 const V a0 = PromoteTo(d, b0);
                 const V v0 = LoadN(d, vec + i, remainder);
                 sum1 = MulAdd(a0, v0, sum1);
               }
               // Reduction tree: sum of all accumulators, then their lanes
               sum2 = Add(sum2, sum3);
               sum0 = Add(sum0, sum1);
               sum0 = Add(sum0, sum2);
               buf[idx_row] = ReduceSum(d, sum0);
               HWY_IF_CONSTEXPR(kAdd) {
                 buf[idx_row] = AddScalar(buf[idx_row], add[begin + idx_row]);
               }
             }  // idx_row
             HWY_UNROLL(4)  // 1..4 iterations
             for (size_t i = 0; i != kChunkSize; i += N) {
               Stream(Load(d, buf + i), d, out + begin + i);
             }
           });
  hwy::FlushStream();

  // Handle remainder rows which are not a multiple of the chunk size.
  for (size_t r = num_chunks * kChunkSize; r < kOuter; ++r) {
    auto sum0 = Zero(d);

    const hwy::bfloat16_t* HWY_RESTRICT row = &mat[r * kInner];
    size_t i = 0;
    HWY_UNROLL(1)
    for (; i + N <= kInner; i += N) {
      const V16H b0 = LoadU(d16h, row + i);
      const V a0 = PromoteTo(d, b0);
      const V v0 = LoadU(d, vec + i);
      sum0 = MulAdd(a0, v0, sum0);
    }
    const size_t remainder = kInner - i;
    if (remainder != 0) {
      const V16H b0 = LoadN(d16h, row + i, remainder);
      const V a0 = PromoteTo(d, b0);
      const V v0 = LoadN(d, vec + i, remainder);
      sum0 = MulAdd(a0, v0, sum0);
    }
    out[r] = ReduceSum(d, sum0);
    HWY_IF_CONSTEXPR(kAdd) { out[r] = AddScalar(out[r], add[r]); }
  }  // r
}

template <size_t kOuter, size_t kInner>
HWY_NOINLINE void MatVecAdd(const hwy::bfloat16_t* HWY_RESTRICT mat,
                            const float* HWY_RESTRICT vec,
                            const float* HWY_RESTRICT add,
                            float* HWY_RESTRICT out, hwy::ThreadPool& pool) {
  MatVecAddImpl<kOuter, kInner, true>(mat, vec, add, out, pool);
}

template <size_t kOuter, size_t kInner>
HWY_NOINLINE void MatVec(const hwy::bfloat16_t* HWY_RESTRICT mat,
                         const float* HWY_RESTRICT vec, float* HWY_RESTRICT out,
                         hwy::ThreadPool& pool) {
  MatVecAddImpl<kOuter, kInner, false>(mat, vec, /*add=*/nullptr, out, pool);
}
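
// The bf16-matrix overloads take the same shapes but with a bf16 matrix, e.g.
// (hypothetical names as in the example above):
//   std::vector<hwy::bfloat16_t> mat_bf(kOuter * kInner);
//   MatVecAdd<kOuter, kInner>(mat_bf.data(), vec.data(), add.data(),
//                             out.data(), pool);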

// Both mat and vec are bf16.
template <size_t kOuter, size_t kInner, bool kAdd>
HWY_NOINLINE void MatVecAddImpl(const hwy::bfloat16_t* HWY_RESTRICT mat,
                                const hwy::bfloat16_t* HWY_RESTRICT vec,
                                const hwy::bfloat16_t* HWY_RESTRICT add,
                                float* HWY_RESTRICT out,
                                hwy::ThreadPool& pool) {
  // Process multiple rows at a time so that we write multiples of a cache line
  // to avoid false sharing (>= 64). 128 is better than 256. 512 has too little
  // parallelization potential.
  constexpr size_t kChunkSize = 64 / sizeof(bfloat16_t);
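  // For bf16 this is 64 / 2 = 32 rows per chunk; the 32 float outputs span
  // two 64-byte cache lines.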
  const uint64_t num_chunks = static_cast<uint64_t>(kOuter / kChunkSize);

  const ScalableTag<float> df;
  const Repartition<hwy::bfloat16_t, decltype(df)> d16;
  using V16 = Vec<decltype(d16)>;
  const size_t N = Lanes(d16);
  // Required for Stream loop, otherwise we might have partial vectors.
  HWY_DASSERT(kChunkSize >= N);
  pool.Run(0, num_chunks,
           [&](const uint64_t chunk, size_t /*thread*/) HWY_ATTR {
             // MSVC workaround: duplicate to ensure constexpr.
             constexpr size_t kChunkSize = 64 / sizeof(bfloat16_t);
             // Software write-combining to avoid cache pollution from out.
             // Although `out` may be used later, keeping it out of the cache
             // now and avoiding RFOs is a consistent 5% overall win.
             HWY_ALIGN float buf[kChunkSize];

             // Only handle entire chunks here because the Stream is not
             // masked. Remaining rows are handled after the pool.Run.
             const size_t begin = static_cast<size_t>(chunk * kChunkSize);
             for (size_t idx_row = 0; idx_row < kChunkSize; ++idx_row) {
               auto sum0 = Zero(df);
               auto sum1 = Zero(df);
               auto sum2 = Zero(df);
               auto sum3 = Zero(df);

               const hwy::bfloat16_t* HWY_RESTRICT row =
                   &mat[(begin + idx_row) * kInner];
               size_t i = 0;
               // No clear win from prefetching from the next 1..3 rows.
               // clflush &row[i] is slow, clflushopt less so but not helping.
               HWY_UNROLL(1)
               for (; i + 2 * N <= kInner; i += 2 * N) {
                 const V16 b0 = LoadU(d16, row + i + 0 * N);
                 const V16 b1 = LoadU(d16, row + i + 1 * N);
                 const V16 v0 = LoadU(d16, vec + i + 0 * N);
                 const V16 v1 = LoadU(d16, vec + i + 1 * N);
                 sum0 = ReorderWidenMulAccumulate(df, b0, v0, sum0, sum1);
                 sum2 = ReorderWidenMulAccumulate(df, b1, v1, sum2, sum3);
               }
               // Last entire vector
               for (; i + N <= kInner; i += N) {
                 const V16 b0 = LoadU(d16, row + i);
                 const V16 v0 = LoadU(d16, vec + i);
                 sum0 = ReorderWidenMulAccumulate(df, b0, v0, sum0, sum1);
               }
               const size_t remainder = kInner - i;
               if (remainder != 0) {
                 const V16 b0 = LoadN(d16, row + i, remainder);
                 const V16 v0 = LoadN(d16, vec + i, remainder);
                 sum2 = ReorderWidenMulAccumulate(df, b0, v0, sum2, sum3);
               }
               // Reduction tree: sum of all accumulators, then their lanes
               sum0 = Add(sum0, sum1);
               sum2 = Add(sum2, sum3);
               sum0 = Add(sum0, sum2);
               buf[idx_row] = ReduceSum(df, sum0);
               HWY_IF_CONSTEXPR(kAdd) {
                 buf[idx_row] = AddScalar(buf[idx_row], add[begin + idx_row]);
               }
             }  // idx_row
             HWY_UNROLL(4)  // 1..4 iterations
             for (size_t i = 0; i != kChunkSize; i += N / 2) {
               Stream(Load(df, buf + i), df, out + begin + i);
             }
           });
  hwy::FlushStream();

  // Handle remainder rows which are not a multiple of the chunk size.
  for (size_t r = num_chunks * kChunkSize; r < kOuter; ++r) {
    auto sum0 = Zero(df);
    auto sum1 = Zero(df);

    const hwy::bfloat16_t* HWY_RESTRICT row = &mat[r * kInner];
    size_t i = 0;
    HWY_UNROLL(1)
    for (; i + N <= kInner; i += N) {
      const V16 b0 = LoadU(d16, row + i);
      const V16 v0 = LoadU(d16, vec + i);
      sum0 = ReorderWidenMulAccumulate(df, b0, v0, sum0, sum1);
    }
    const size_t remainder = kInner - i;
    if (remainder != 0) {
      const V16 b0 = LoadN(d16, row + i, remainder);
      const V16 v0 = LoadN(d16, vec + i, remainder);
      sum0 = ReorderWidenMulAccumulate(df, b0, v0, sum0, sum1);
    }
    out[r] = ReduceSum(df, Add(sum0, sum1));
    HWY_IF_CONSTEXPR(kAdd) { out[r] = AddScalar(out[r], add[r]); }
  }  // r
}

template <size_t kOuter, size_t kInner>
HWY_NOINLINE void MatVecAdd(const hwy::bfloat16_t* HWY_RESTRICT mat,
                            const hwy::bfloat16_t* HWY_RESTRICT vec,
                            const hwy::bfloat16_t* HWY_RESTRICT add,
                            float* HWY_RESTRICT out, hwy::ThreadPool& pool) {
  MatVecAddImpl<kOuter, kInner, true>(mat, vec, add, out, pool);
}

template <size_t kOuter, size_t kInner>
HWY_NOINLINE void MatVec(const hwy::bfloat16_t* HWY_RESTRICT mat,
                         const hwy::bfloat16_t* HWY_RESTRICT vec,
                         float* HWY_RESTRICT out, hwy::ThreadPool& pool) {
  MatVecAddImpl<kOuter, kInner, false>(mat, vec, /*add=*/nullptr, out, pool);
}
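
// With both inputs in bf16, accumulation still happens in f32 and out remains
// float, e.g. (hypothetical names as in the examples above):
//   std::vector<hwy::bfloat16_t> mat_bf(kOuter * kInner), vec_bf(kInner);
//   MatVec<kOuter, kInner>(mat_bf.data(), vec_bf.data(), out.data(), pool);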

#endif  // HWY_TARGET != HWY_SCALAR

// NOLINTNEXTLINE(google-readability-namespace-comments)
}  // namespace HWY_NAMESPACE
}  // namespace hwy
HWY_AFTER_NAMESPACE();

#endif  // HIGHWAY_HWY_CONTRIB_MATVEC_MATVEC_INL_H_