17#if defined(HIGHWAY_HWY_CONTRIB_DOT_DOT_INL_H_) == defined(HWY_TARGET_TOGGLE)
19#ifdef HIGHWAY_HWY_CONTRIB_DOT_DOT_INL_H_
20#undef HIGHWAY_HWY_CONTRIB_DOT_DOT_INL_H_
22#define HIGHWAY_HWY_CONTRIB_DOT_DOT_INL_H_
52 template <
int kAssumptions,
class D,
typename T = TFromD<D>>
55 const size_t num_elements) {
56 static_assert(IsFloat<T>(),
"MulAdd requires float type");
57 using V =
decltype(
Zero(
d));
62 constexpr bool kIsAtLeastOneVector =
64 constexpr bool kIsMultipleOfVector =
66 constexpr bool kIsPaddedToVector = (kAssumptions &
kPaddedToVector) != 0;
69 if (!kIsAtLeastOneVector && !kIsMultipleOfVector && !kIsPaddedToVector &&
72 T sum0 = ConvertScalarTo<T>(0);
73 T sum1 = ConvertScalarTo<T>(0);
74 for (; i + 2 <= num_elements; i += 2) {
76 sum0 = ConvertScalarTo<T>(sum0 + pa[i + 0] * pb[i + 0]);
77 sum1 = ConvertScalarTo<T>(sum1 + pa[i + 1] * pb[i + 1]);
79 if (i < num_elements) {
80 sum1 = ConvertScalarTo<T>(sum1 + pa[i] * pb[i]);
82 return ConvertScalarTo<T>(sum0 + sum1);
96 for (; i + 4 * N <= num_elements; ) {
97 const auto a0 =
LoadU(
d, pa + i);
98 const auto b0 =
LoadU(
d, pb + i);
100 sum0 =
MulAdd(a0, b0, sum0);
101 const auto a1 =
LoadU(
d, pa + i);
102 const auto b1 =
LoadU(
d, pb + i);
104 sum1 =
MulAdd(a1, b1, sum1);
105 const auto a2 =
LoadU(
d, pa + i);
106 const auto b2 =
LoadU(
d, pb + i);
108 sum2 =
MulAdd(a2, b2, sum2);
109 const auto a3 =
LoadU(
d, pa + i);
110 const auto b3 =
LoadU(
d, pb + i);
112 sum3 =
MulAdd(a3, b3, sum3);
116 for (; i + N <= num_elements; i += N) {
117 const auto a =
LoadU(
d, pa + i);
118 const auto b =
LoadU(
d, pb + i);
119 sum0 =
MulAdd(a, b, sum0);
122 if (!kIsMultipleOfVector) {
123 const size_t remaining = num_elements - i;
124 if (remaining != 0) {
125 if (kIsPaddedToVector) {
126 const auto mask =
FirstN(
d, remaining);
127 const auto a =
LoadU(
d, pa + i);
128 const auto b =
LoadU(
d, pb + i);
136 const auto skip =
FirstN(
d, N - remaining);
137 const auto a =
LoadU(
d, pa + i);
138 const auto b =
LoadU(
d, pb + i);
145 sum0 =
Add(sum0, sum1);
146 sum2 =
Add(sum2, sum3);
147 sum0 =
Add(sum0, sum2);
152 template <
int kAssumptions,
class DF, HWY_IF_F32_D(DF)>
156 const size_t num_elements) {
157#if HWY_TARGET == HWY_SCALAR
161 using VBF =
decltype(
Zero(dbf));
163 const Half<
decltype(dbf)> dbfh;
164 using VF =
decltype(
Zero(df));
166 const size_t NF =
Lanes(df);
168 constexpr bool kIsAtLeastOneVector =
170 constexpr bool kIsMultipleOfVector =
172 constexpr bool kIsPaddedToVector = (kAssumptions &
kPaddedToVector) != 0;
175 if (!kIsAtLeastOneVector && !kIsMultipleOfVector && !kIsPaddedToVector &&
181 for (; i + 2 <= num_elements; i += 2) {
182 sum0 += pa[i + 0] * ConvertScalarTo<float>(pb[i + 0]);
183 sum1 += pa[i + 1] * ConvertScalarTo<float>(pb[i + 1]);
185 for (; i < num_elements; ++i) {
186 sum1 += pa[i] * ConvertScalarTo<float>(pb[i]);
203#if HWY_TARGET != HWY_SCALAR
205 for (; i + 4 * NF <= num_elements; ) {
206 const VF a0 =
LoadU(df, pa + i);
207 const VBF b0 =
LoadU(dbf, pb + i);
210 const VF a1 =
LoadU(df, pa + i);
213 const VF a2 =
LoadU(df, pa + i);
214 const VBF b2 =
LoadU(dbf, pb + i);
217 const VF a3 =
LoadU(df, pa + i);
224 for (; i + NF <= num_elements; i += NF) {
225 const VF a =
LoadU(df, pa + i);
227 sum0 =
MulAdd(a, b, sum0);
230 if (!kIsMultipleOfVector) {
231 const size_t remaining = num_elements - i;
232 if (remaining != 0) {
233 if (kIsPaddedToVector) {
234 const auto mask =
FirstN(df, remaining);
235 const VF a =
LoadU(df, pa + i);
244 const auto skip =
FirstN(df, NF - remaining);
245 const VF a =
LoadU(df, pa + i);
253 sum0 =
Add(sum0, sum1);
254 sum2 =
Add(sum2, sum3);
255 sum0 =
Add(sum0, sum2);
261 template <
int kAssumptions,
class D, HWY_IF_BF16_D(D)>
265 const size_t num_elements) {
269 using V =
decltype(
Zero(df32));
270 const size_t N =
Lanes(
d);
273 constexpr bool kIsAtLeastOneVector =
275 constexpr bool kIsMultipleOfVector =
277 constexpr bool kIsPaddedToVector = (kAssumptions &
kPaddedToVector) != 0;
280 if (!kIsAtLeastOneVector && !kIsMultipleOfVector && !kIsPaddedToVector &&
284 for (; i + 2 <= num_elements; i += 2) {
288 if (i < num_elements) {
302 for (; i + 2 * N <= num_elements; ) {
303 const auto a0 =
LoadU(
d, pa + i);
304 const auto b0 =
LoadU(
d, pb + i);
307 const auto a1 =
LoadU(
d, pa + i);
308 const auto b1 =
LoadU(
d, pb + i);
314 if (i + N <= num_elements) {
315 const auto a0 =
LoadU(
d, pa + i);
316 const auto b0 =
LoadU(
d, pb + i);
321 if (!kIsMultipleOfVector) {
322 const size_t remaining = num_elements - i;
323 if (remaining != 0) {
324 if (kIsPaddedToVector) {
325 const auto mask =
FirstN(du16, remaining);
326 const auto va =
LoadU(
d, pa + i);
327 const auto vb =
LoadU(
d, pb + i);
338 const auto skip =
FirstN(du16, N - remaining);
339 const auto va =
LoadU(
d, pa + i);
340 const auto vb =
LoadU(
d, pb + i);
349 sum0 =
Add(sum0, sum1);
350 sum2 =
Add(sum2, sum3);
351 sum0 =
Add(sum0, sum2);
#define HWY_RESTRICT
Definition base.h:95
#define HWY_INLINE
Definition base.h:101
#define HWY_DASSERT(condition)
Definition base.h:290
#define HWY_UNLIKELY(expr)
Definition base.h:107
typename D::template Rebind< T > Rebind
Definition ops/shared-inl.h:460
D d
Definition arm_sve-inl.h:1915
HWY_API VFromD< D > BitCast(D d, Vec128< FromT, Repartition< FromT, D >().MaxLanes()> v)
Definition arm_neon-inl.h:1581
HWY_API VFromD< D32 > ReorderWidenMulAccumulate(D32 df32, V16 a, V16 b, const VFromD< D32 > sum0, VFromD< D32 > &sum1)
Definition arm_neon-inl.h:6571
HWY_API Vec128< T, N > MulAdd(Vec128< T, N > mul, Vec128< T, N > x, Vec128< T, N > add)
Definition arm_neon-inl.h:2550
HWY_API Vec128< T, N > IfThenZeroElse(Mask128< T, N > mask, Vec128< T, N > no)
Definition arm_neon-inl.h:3019
HWY_API Vec128< uint8_t > LoadU(D, const uint8_t *HWY_RESTRICT unaligned)
Definition arm_neon-inl.h:3442
HWY_API VFromD< D > Zero(D d)
Definition arm_neon-inl.h:947
HWY_API V Add(V a, V b)
Definition generic_ops-inl.h:7300
HWY_API VFromD< D > PromoteLowerTo(D d, V v)
Definition generic_ops-inl.h:2984
Rebind< MakeUnsigned< TFromD< D > >, D > RebindToUnsigned
Definition ops/shared-inl.h:465
HWY_API Vec128< uint16_t > PromoteTo(D, Vec64< uint8_t > v)
Definition arm_neon-inl.h:4252
HWY_API Vec128< T, N > IfThenElseZero(Mask128< T, N > mask, Vec128< T, N > yes)
Definition arm_neon-inl.h:3007
typename D::Half Half
Definition ops/shared-inl.h:487
HWY_API VFromD< D > PromoteUpperTo(D d, V v)
Definition arm_sve-inl.h:2228
HWY_API size_t Lanes(D)
Definition rvv-inl.h:598
HWY_API MFromD< D > FirstN(D d, size_t num)
Definition arm_neon-inl.h:3232
HWY_API TFromD< D > ReduceSum(D, VFromD< D > v)
Definition arm_neon-inl.h:8027
typename D::template Repartition< T > Repartition
Definition ops/shared-inl.h:471
HWY_API HWY_BF16_CONSTEXPR float F32FromBF16(bfloat16_t bf)
Definition base.h:1778
#define HWY_NAMESPACE
Definition set_macros-inl.h:166
static HWY_INLINE float Compute(const D d, const bfloat16_t *const HWY_RESTRICT pa, const bfloat16_t *const HWY_RESTRICT pb, const size_t num_elements)
Definition dot-inl.h:262
static HWY_INLINE float Compute(const DF df, const float *const HWY_RESTRICT pa, const hwy::bfloat16_t *const HWY_RESTRICT pb, const size_t num_elements)
Definition dot-inl.h:153
static HWY_INLINE T Compute(const D d, const T *const HWY_RESTRICT pa, const T *const HWY_RESTRICT pb, const size_t num_elements)
Definition dot-inl.h:53
Assumptions
Definition dot-inl.h:37
@ kMultipleOfVector
Definition dot-inl.h:43
@ kPaddedToVector
Definition dot-inl.h:46
@ kAtLeastOneVector
Definition dot-inl.h:39