dot-inl.h
// Copyright 2021 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// clang-format off
#if defined(HIGHWAY_HWY_CONTRIB_DOT_DOT_INL_H_) == defined(HWY_TARGET_TOGGLE)  // NOLINT
// clang-format on
#ifdef HIGHWAY_HWY_CONTRIB_DOT_DOT_INL_H_
#undef HIGHWAY_HWY_CONTRIB_DOT_DOT_INL_H_
#else
#define HIGHWAY_HWY_CONTRIB_DOT_DOT_INL_H_
#endif

#include <stddef.h>

#include "hwy/highway.h"

HWY_BEFORE_NAMESPACE();
namespace hwy {
namespace HWY_NAMESPACE {

struct Dot {
  // Specify zero or more of these, ORed together, as the kAssumptions template
  // argument to Compute. Each one may improve performance or reduce code size,
  // at the cost of additional requirements on the arguments.
  enum Assumptions {
    // num_elements is at least N, which may be up to HWY_MAX_BYTES / sizeof(T).
    kAtLeastOneVector = 1,
    // num_elements is divisible by N (a power of two, so this can be used if
    // the problem size is known to be a power of two >= HWY_MAX_BYTES /
    // sizeof(T)).
    kMultipleOfVector = 2,
    // RoundUpTo(num_elements, N) elements are accessible; their value does not
    // matter (will be treated as if they were zero).
    kPaddedToVector = 4,
  };

  // Returns sum{pa[i] * pb[i]} for floating-point inputs, including float16_t
  // and double if HWY_HAVE_FLOAT16/64. Aligning the
  // pointers to a multiple of N elements is helpful but not required.
  template <int kAssumptions, class D, typename T = TFromD<D>>
  static HWY_INLINE T Compute(const D d, const T* const HWY_RESTRICT pa,
                              const T* const HWY_RESTRICT pb,
                              const size_t num_elements) {
    static_assert(IsFloat<T>(), "MulAdd requires float type");
    using V = decltype(Zero(d));

    const size_t N = Lanes(d);
    size_t i = 0;

    constexpr bool kIsAtLeastOneVector =
        (kAssumptions & kAtLeastOneVector) != 0;
    constexpr bool kIsMultipleOfVector =
        (kAssumptions & kMultipleOfVector) != 0;
    constexpr bool kIsPaddedToVector = (kAssumptions & kPaddedToVector) != 0;

    // Won't be able to do a full vector load without padding => scalar loop.
    if (!kIsAtLeastOneVector && !kIsMultipleOfVector && !kIsPaddedToVector &&
        HWY_UNLIKELY(num_elements < N)) {
      // Only 2x unroll to avoid excessive code size.
      T sum0 = ConvertScalarTo<T>(0);
      T sum1 = ConvertScalarTo<T>(0);
      for (; i + 2 <= num_elements; i += 2) {
        // For reasons unknown, fp16 += does not compile on clang (Arm).
        sum0 = ConvertScalarTo<T>(sum0 + pa[i + 0] * pb[i + 0]);
        sum1 = ConvertScalarTo<T>(sum1 + pa[i + 1] * pb[i + 1]);
      }
      if (i < num_elements) {
        sum1 = ConvertScalarTo<T>(sum1 + pa[i] * pb[i]);
      }
      return ConvertScalarTo<T>(sum0 + sum1);
    }

    // Compiler doesn't make independent sum* accumulators, so unroll manually.
    // 2 FMA ports * 4 cycle latency = up to 8 in-flight, but that is excessive
    // for unaligned inputs (each unaligned pointer halves the throughput
    // because it occupies both L1 load ports for a cycle). We cannot have
    // arrays of vectors on RVV/SVE, so always unroll 4x.
    V sum0 = Zero(d);
    V sum1 = Zero(d);
    V sum2 = Zero(d);
    V sum3 = Zero(d);

    // Main loop: unrolled
    for (; i + 4 * N <= num_elements; /* i += 4 * N */) {  // incr in loop
      const auto a0 = LoadU(d, pa + i);
      const auto b0 = LoadU(d, pb + i);
      i += N;
      sum0 = MulAdd(a0, b0, sum0);
      const auto a1 = LoadU(d, pa + i);
      const auto b1 = LoadU(d, pb + i);
      i += N;
      sum1 = MulAdd(a1, b1, sum1);
      const auto a2 = LoadU(d, pa + i);
      const auto b2 = LoadU(d, pb + i);
      i += N;
      sum2 = MulAdd(a2, b2, sum2);
      const auto a3 = LoadU(d, pa + i);
      const auto b3 = LoadU(d, pb + i);
      i += N;
      sum3 = MulAdd(a3, b3, sum3);
    }

    // Up to 3 iterations of whole vectors
    for (; i + N <= num_elements; i += N) {
      const auto a = LoadU(d, pa + i);
      const auto b = LoadU(d, pb + i);
      sum0 = MulAdd(a, b, sum0);
    }

    if (!kIsMultipleOfVector) {
      const size_t remaining = num_elements - i;
      if (remaining != 0) {
        if (kIsPaddedToVector) {
          const auto mask = FirstN(d, remaining);
          const auto a = LoadU(d, pa + i);
          const auto b = LoadU(d, pb + i);
          sum1 = MulAdd(IfThenElseZero(mask, a), IfThenElseZero(mask, b), sum1);
        } else {
          // Unaligned load such that the last element is in the highest lane -
          // ensures we do not touch any elements outside the valid range.
          // If we get here, then num_elements >= N.
          HWY_DASSERT(i >= N);
          i += remaining - N;
          const auto skip = FirstN(d, N - remaining);
          const auto a = LoadU(d, pa + i);  // always unaligned
          const auto b = LoadU(d, pb + i);
          sum1 = MulAdd(IfThenZeroElse(skip, a), IfThenZeroElse(skip, b), sum1);
        }
      }
    }  // kMultipleOfVector

    // Reduction tree: sum of all accumulators by pairs, then across lanes.
    sum0 = Add(sum0, sum1);
    sum2 = Add(sum2, sum3);
    sum0 = Add(sum0, sum2);
    return ReduceSum(d, sum0);
  }

  // f32 * bf16
  template <int kAssumptions, class DF, HWY_IF_F32_D(DF)>
  static HWY_INLINE float Compute(const DF df,
                                  const float* const HWY_RESTRICT pa,
                                  const hwy::bfloat16_t* const HWY_RESTRICT pb,
                                  const size_t num_elements) {
#if HWY_TARGET == HWY_SCALAR
    const Rebind<hwy::bfloat16_t, DF> dbf;
#else
    const Repartition<hwy::bfloat16_t, DF> dbf;
    using VBF = decltype(Zero(dbf));
#endif
    const Half<decltype(dbf)> dbfh;
    using VF = decltype(Zero(df));

    const size_t NF = Lanes(df);

    constexpr bool kIsAtLeastOneVector =
        (kAssumptions & kAtLeastOneVector) != 0;
    constexpr bool kIsMultipleOfVector =
        (kAssumptions & kMultipleOfVector) != 0;
    constexpr bool kIsPaddedToVector = (kAssumptions & kPaddedToVector) != 0;

    // Won't be able to do a full vector load without padding => scalar loop.
    if (!kIsAtLeastOneVector && !kIsMultipleOfVector && !kIsPaddedToVector &&
        HWY_UNLIKELY(num_elements < NF)) {
      // Only 2x unroll to avoid excessive code size.
      float sum0 = 0.0f;
      float sum1 = 0.0f;
      size_t i = 0;
      for (; i + 2 <= num_elements; i += 2) {
        sum0 += pa[i + 0] * ConvertScalarTo<float>(pb[i + 0]);
        sum1 += pa[i + 1] * ConvertScalarTo<float>(pb[i + 1]);
      }
      for (; i < num_elements; ++i) {
        sum1 += pa[i] * ConvertScalarTo<float>(pb[i]);
      }
      return sum0 + sum1;
    }

    // Compiler doesn't make independent sum* accumulators, so unroll manually.
    // 2 FMA ports * 4 cycle latency = up to 8 in-flight, but that is excessive
    // for unaligned inputs (each unaligned pointer halves the throughput
    // because it occupies both L1 load ports for a cycle). We cannot have
    // arrays of vectors on RVV/SVE, so always unroll 4x.
    VF sum0 = Zero(df);
    VF sum1 = Zero(df);
    VF sum2 = Zero(df);
    VF sum3 = Zero(df);

    size_t i = 0;

#if HWY_TARGET != HWY_SCALAR  // PromoteUpperTo supported
    // Main loop: unrolled
    for (; i + 4 * NF <= num_elements; /* i += 4 * NF */) {  // incr in loop
      const VF a0 = LoadU(df, pa + i);
      const VBF b0 = LoadU(dbf, pb + i);
      i += NF;
      sum0 = MulAdd(a0, PromoteLowerTo(df, b0), sum0);
      const VF a1 = LoadU(df, pa + i);
      i += NF;
      sum1 = MulAdd(a1, PromoteUpperTo(df, b0), sum1);
      const VF a2 = LoadU(df, pa + i);
      const VBF b2 = LoadU(dbf, pb + i);
      i += NF;
      sum2 = MulAdd(a2, PromoteLowerTo(df, b2), sum2);
      const VF a3 = LoadU(df, pa + i);
      i += NF;
      sum3 = MulAdd(a3, PromoteUpperTo(df, b2), sum3);
    }
#endif  // HWY_TARGET != HWY_SCALAR

    // Up to 3 iterations of whole vectors
    for (; i + NF <= num_elements; i += NF) {
      const VF a = LoadU(df, pa + i);
      const VF b = PromoteTo(df, LoadU(dbfh, pb + i));
      sum0 = MulAdd(a, b, sum0);
    }

    if (!kIsMultipleOfVector) {
      const size_t remaining = num_elements - i;
      if (remaining != 0) {
        if (kIsPaddedToVector) {
          const auto mask = FirstN(df, remaining);
          const VF a = LoadU(df, pa + i);
          const VF b = PromoteTo(df, LoadU(dbfh, pb + i));
          sum1 = MulAdd(IfThenElseZero(mask, a), IfThenElseZero(mask, b), sum1);
        } else {
          // Unaligned load such that the last element is in the highest lane -
          // ensures we do not touch any elements outside the valid range.
          // If we get here, then num_elements >= NF.
          HWY_DASSERT(i >= NF);
          i += remaining - NF;
          const auto skip = FirstN(df, NF - remaining);
          const VF a = LoadU(df, pa + i);  // always unaligned
          const VF b = PromoteTo(df, LoadU(dbfh, pb + i));
          sum1 = MulAdd(IfThenZeroElse(skip, a), IfThenZeroElse(skip, b), sum1);
        }
      }
    }  // kMultipleOfVector

    // Reduction tree: sum of all accumulators by pairs, then across lanes.
    sum0 = Add(sum0, sum1);
    sum2 = Add(sum2, sum3);
    sum0 = Add(sum0, sum2);
    return ReduceSum(df, sum0);
  }

  // Returns sum{pa[i] * pb[i]} for bfloat16 inputs. Aligning the pointers to a
  // multiple of N elements is helpful but not required.
  template <int kAssumptions, class D, HWY_IF_BF16_D(D)>
  static HWY_INLINE float Compute(const D d,
                                  const bfloat16_t* const HWY_RESTRICT pa,
                                  const bfloat16_t* const HWY_RESTRICT pb,
                                  const size_t num_elements) {
    const RebindToUnsigned<D> du16;
    const Repartition<float, D> df32;

    using V = decltype(Zero(df32));
    const size_t N = Lanes(d);
    size_t i = 0;

    constexpr bool kIsAtLeastOneVector =
        (kAssumptions & kAtLeastOneVector) != 0;
    constexpr bool kIsMultipleOfVector =
        (kAssumptions & kMultipleOfVector) != 0;
    constexpr bool kIsPaddedToVector = (kAssumptions & kPaddedToVector) != 0;

    // Won't be able to do a full vector load without padding => scalar loop.
    if (!kIsAtLeastOneVector && !kIsMultipleOfVector && !kIsPaddedToVector &&
        HWY_UNLIKELY(num_elements < N)) {
      float sum0 = 0.0f;  // Only 2x unroll to avoid excessive code size for..
      float sum1 = 0.0f;  // this unlikely(?) case.
      for (; i + 2 <= num_elements; i += 2) {
        sum0 += F32FromBF16(pa[i + 0]) * F32FromBF16(pb[i + 0]);
        sum1 += F32FromBF16(pa[i + 1]) * F32FromBF16(pb[i + 1]);
      }
      if (i < num_elements) {
        sum1 += F32FromBF16(pa[i]) * F32FromBF16(pb[i]);
      }
      return sum0 + sum1;
    }

    // See comment in the other Compute() overload. Unroll 2x, but we need
    // twice as many sums for ReorderWidenMulAccumulate.
    V sum0 = Zero(df32);
    V sum1 = Zero(df32);
    V sum2 = Zero(df32);
    V sum3 = Zero(df32);

    // Main loop: unrolled
    for (; i + 2 * N <= num_elements; /* i += 2 * N */) {  // incr in loop
      const auto a0 = LoadU(d, pa + i);
      const auto b0 = LoadU(d, pb + i);
      i += N;
      sum0 = ReorderWidenMulAccumulate(df32, a0, b0, sum0, sum1);
      const auto a1 = LoadU(d, pa + i);
      const auto b1 = LoadU(d, pb + i);
      i += N;
      sum2 = ReorderWidenMulAccumulate(df32, a1, b1, sum2, sum3);
    }

    // Possibly one more iteration of whole vectors
    if (i + N <= num_elements) {
      const auto a0 = LoadU(d, pa + i);
      const auto b0 = LoadU(d, pb + i);
      i += N;
      sum0 = ReorderWidenMulAccumulate(df32, a0, b0, sum0, sum1);
    }

    if (!kIsMultipleOfVector) {
      const size_t remaining = num_elements - i;
      if (remaining != 0) {
        if (kIsPaddedToVector) {
          const auto mask = FirstN(du16, remaining);
          const auto va = LoadU(d, pa + i);
          const auto vb = LoadU(d, pb + i);
          const auto a16 = BitCast(d, IfThenElseZero(mask, BitCast(du16, va)));
          const auto b16 = BitCast(d, IfThenElseZero(mask, BitCast(du16, vb)));
          sum2 = ReorderWidenMulAccumulate(df32, a16, b16, sum2, sum3);

        } else {
          // Unaligned load such that the last element is in the highest lane -
          // ensures we do not touch any elements outside the valid range.
          // If we get here, then num_elements >= N.
          HWY_DASSERT(i >= N);
          i += remaining - N;
          const auto skip = FirstN(du16, N - remaining);
          const auto va = LoadU(d, pa + i);  // always unaligned
          const auto vb = LoadU(d, pb + i);
          const auto a16 = BitCast(d, IfThenZeroElse(skip, BitCast(du16, va)));
          const auto b16 = BitCast(d, IfThenZeroElse(skip, BitCast(du16, vb)));
          sum2 = ReorderWidenMulAccumulate(df32, a16, b16, sum2, sum3);
        }
      }
    }  // kMultipleOfVector

    // Reduction tree: sum of all accumulators by pairs, then across lanes.
    sum0 = Add(sum0, sum1);
    sum2 = Add(sum2, sum3);
    sum0 = Add(sum0, sum2);
    return ReduceSum(df32, sum0);
  }
};

// NOLINTNEXTLINE(google-readability-namespace-comments)
}  // namespace HWY_NAMESPACE
}  // namespace hwy
HWY_AFTER_NAMESPACE();

#endif  // HIGHWAY_HWY_CONTRIB_DOT_DOT_INL_H_
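
Usage sketch (not part of the header above): a caller includes "hwy/highway.h" and this file, then passes zero or more of the Assumptions flags, ORed together, as the kAssumptions template argument of Dot::Compute. The example below is a minimal sketch assuming static dispatch to the baseline target; the file name, function names and the buffer-padding contract are illustrative assumptions, while Dot::Compute, the flag names and ScalableTag come from Highway itself.

// example_dot.cc (hypothetical): static dispatch, single target.
#include <stddef.h>

#include "hwy/highway.h"
#include "hwy/contrib/dot/dot-inl.h"

namespace hn = hwy::HWY_NAMESPACE;

// Caller guarantees RoundUpTo(num, Lanes(d)) readable elements in a and b,
// which is exactly what kPaddedToVector requires; padding values are ignored.
float DotF32(const float* HWY_RESTRICT a, const float* HWY_RESTRICT b,
             size_t num) {
  const hn::ScalableTag<float> d;
  return hn::Dot::Compute<hn::Dot::kPaddedToVector>(d, a, b, num);
}

// No assumptions (kAssumptions = 0): Compute falls back to a scalar loop for
// inputs shorter than one vector and handles the remainder with masked or
// overlapping loads, so it is correct for any num.
float DotBF16(const hwy::bfloat16_t* HWY_RESTRICT a,
              const hwy::bfloat16_t* HWY_RESTRICT b, size_t num) {
  const hn::ScalableTag<hwy::bfloat16_t> d;
  return hn::Dot::Compute<0>(d, a, b, num);
}

The flags only trade generality for speed or code size: kMultipleOfVector and kPaddedToVector let Compute skip or simplify the remainder handling shown in the header, while passing 0 keeps the fully general path.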