17#if defined(HIGHWAY_HWY_CONTRIB_MATVEC_MATVEC_INL_H_) == \
18 defined(HWY_TARGET_TOGGLE)
19#ifdef HIGHWAY_HWY_CONTRIB_MATVEC_MATVEC_INL_H_
20#undef HIGHWAY_HWY_CONTRIB_MATVEC_MATVEC_INL_H_
22#define HIGHWAY_HWY_CONTRIB_MATVEC_MATVEC_INL_H_
33template <
typename TA,
typename TB>
35 return ConvertScalarTo<TA>(ConvertScalarTo<float>(a) +
36 ConvertScalarTo<float>(b));
39template <
size_t kOuter,
size_t kInner,
typename T,
bool kAdd>
49 constexpr size_t kChunkSize = 64 /
sizeof(T);
50 const uint64_t num_chunks =
static_cast<uint64_t
>(kOuter / kChunkSize);
56 pool.
Run(0, num_chunks,
57 [&](
const uint64_t chunk,
size_t )
HWY_ATTR {
59 constexpr size_t kChunkSize = 64 /
sizeof(T);
67 const size_t begin =
static_cast<size_t>(chunk * kChunkSize);
68 for (
size_t idx_row = 0; idx_row < kChunkSize; ++idx_row) {
75 const T*
HWY_RESTRICT row = &mat[(begin + idx_row) * kInner];
80 for (; i + 4 * N <= kInner; i += 4 * N) {
81 const auto a0 =
LoadU(
d, row + i + 0 * N);
82 const auto v0 =
LoadU(
d, vec + i + 0 * N);
83 sum0 =
MulAdd(a0, v0, sum0);
85 const auto a1 =
LoadU(
d, row + i + 1 * N);
86 const auto v1 =
LoadU(
d, vec + i + 1 * N);
87 sum1 =
MulAdd(a1, v1, sum1);
89 const auto a2 =
LoadU(
d, row + i + 2 * N);
90 const auto v2 =
LoadU(
d, vec + i + 2 * N);
91 sum2 =
MulAdd(a2, v2, sum2);
93 const auto a3 =
LoadU(
d, row + i + 3 * N);
94 const auto v3 =
LoadU(
d, vec + i + 3 * N);
95 sum3 =
MulAdd(a3, v3, sum3);
98 for (; i + N <= kInner; i += N) {
99 const auto a0 =
LoadU(
d, row + i);
100 const auto v0 =
LoadU(
d, vec + i);
101 sum0 =
MulAdd(a0, v0, sum0);
103 const size_t remainder = kInner - i;
104 if (remainder != 0) {
105 const auto a0 =
LoadN(
d, row + i, remainder);
106 const auto v0 =
LoadN(
d, vec + i, remainder);
107 sum1 =
MulAdd(a0, v0, sum1);
110 sum2 =
Add(sum2, sum3);
111 sum0 =
Add(sum0, sum1);
112 sum0 =
Add(sum0, sum2);
115 buf[idx_row] =
AddScalar(buf[idx_row], add[begin + idx_row]);
119 for (
size_t i = 0; i != kChunkSize; i += N) {
126 for (
size_t r = num_chunks * kChunkSize; r < kOuter; ++r) {
132 for (; i + N <= kInner; i += N) {
133 const auto a0 =
LoadU(
d, row + i);
134 const auto v0 =
LoadU(
d, vec + i);
135 sum0 =
MulAdd(a0, v0, sum0);
137 const size_t remainder = kInner - i;
138 if (remainder != 0) {
139 const auto a0 =
LoadN(
d, row + i, remainder);
140 const auto v0 =
LoadN(
d, vec + i, remainder);
141 sum0 =
MulAdd(a0, v0, sum0);
158template <
size_t kOuter,
size_t kInner,
typename T>
163 MatVecAddImpl<kOuter, kInner, T, true>(mat, vec, add, out, pool);
174template <
size_t kOuter,
size_t kInner,
typename T>
177 MatVecAddImpl<kOuter, kInner, T, false>(mat, vec,
nullptr, out, pool);
182#if HWY_TARGET != HWY_SCALAR
185template <
size_t kOuter,
size_t kInner,
bool kAdd>
194 constexpr size_t kChunkSize = 64 /
sizeof(float);
195 const uint64_t num_chunks =
static_cast<uint64_t
>(kOuter / kChunkSize);
197 const ScalableTag<float>
d;
201 const Half<
decltype(d16)> d16h;
202 using V =
Vec<
decltype(
d)>;
203 using V16 =
Vec<
decltype(d16)>;
204 using V16H =
Vec<
decltype(d16h)>;
205 const size_t N =
Lanes(
d);
208 pool.
Run(0, num_chunks,
209 [&](
const uint64_t chunk,
size_t )
HWY_ATTR {
211 constexpr size_t kChunkSize = 64 /
sizeof(float);
219 const size_t begin =
static_cast<size_t>(chunk * kChunkSize);
220 for (
size_t idx_row = 0; idx_row < kChunkSize; ++idx_row) {
228 &mat[(begin + idx_row) * kInner];
233 for (; i + 4 * N <= kInner; i += 4 * N) {
234 const V16 b0 =
LoadU(d16, row + i + 0 * N);
238 const V16 b1 =
LoadU(d16, row + i + 2 * N);
242 const V v0 =
LoadU(
d, vec + i + 0 * N);
243 sum0 =
MulAdd(a0, v0, sum0);
245 const V v1 =
LoadU(
d, vec + i + 1 * N);
246 sum1 =
MulAdd(a1, v1, sum1);
248 const V v2 =
LoadU(
d, vec + i + 2 * N);
249 sum2 =
MulAdd(a2, v2, sum2);
251 const V v3 =
LoadU(
d, vec + i + 3 * N);
252 sum3 =
MulAdd(a3, v3, sum3);
255 for (; i + N <= kInner; i += N) {
256 const V16H b0 =
LoadU(d16h, row + i);
258 const V v0 =
LoadU(
d, vec + i);
259 sum0 =
MulAdd(a0, v0, sum0);
261 const size_t remainder = kInner - i;
262 if (remainder != 0) {
263 const V16H b0 =
LoadN(d16h, row + i, remainder);
265 const V v0 =
LoadN(
d, vec + i, remainder);
266 sum1 =
MulAdd(a0, v0, sum1);
269 sum2 =
Add(sum2, sum3);
270 sum0 =
Add(sum0, sum1);
271 sum0 =
Add(sum0, sum2);
274 buf[idx_row] =
AddScalar(buf[idx_row], add[begin + idx_row]);
278 for (
size_t i = 0; i != kChunkSize; i += N) {
285 for (
size_t r = num_chunks * kChunkSize; r < kOuter; ++r) {
291 for (; i + N <= kInner; i += N) {
292 const V16H b0 =
LoadU(d16h, row + i);
294 const V v0 =
LoadU(
d, vec + i);
295 sum0 =
MulAdd(a0, v0, sum0);
297 const size_t remainder = kInner - i;
298 if (remainder != 0) {
299 const V16H b0 =
LoadN(d16h, row + i, remainder);
301 const V v0 =
LoadN(
d, vec + i, remainder);
302 sum0 =
MulAdd(a0, v0, sum0);
309template <
size_t kOuter,
size_t kInner>
314 MatVecAddImpl<kOuter, kInner, true>(mat, vec, add, out, pool);
317template <
size_t kOuter,
size_t kInner>
321 MatVecAddImpl<kOuter, kInner, false>(mat, vec,
nullptr, out, pool);
325template <
size_t kOuter,
size_t kInner,
bool kAdd>
334 constexpr size_t kChunkSize = 64 /
sizeof(bfloat16_t);
335 const uint64_t num_chunks =
static_cast<uint64_t
>(kOuter / kChunkSize);
337 const ScalableTag<float> df;
339 using V16 =
Vec<
decltype(d16)>;
340 const size_t N =
Lanes(d16);
343 pool.
Run(0, num_chunks,
344 [&](
const uint64_t chunk,
size_t )
HWY_ATTR {
346 constexpr size_t kChunkSize = 64 /
sizeof(bfloat16_t);
354 const size_t begin =
static_cast<size_t>(chunk * kChunkSize);
355 for (
size_t idx_row = 0; idx_row < kChunkSize; ++idx_row) {
356 auto sum0 =
Zero(df);
357 auto sum1 =
Zero(df);
358 auto sum2 =
Zero(df);
359 auto sum3 =
Zero(df);
362 &mat[(begin + idx_row) * kInner];
367 for (; i + 2 * N <= kInner; i += 2 * N) {
368 const V16 b0 =
LoadU(d16, row + i + 0 * N);
369 const V16 b1 =
LoadU(d16, row + i + 1 * N);
370 const V16 v0 =
LoadU(d16, vec + i + 0 * N);
371 const V16 v1 =
LoadU(d16, vec + i + 1 * N);
376 for (; i + N <= kInner; i += N) {
377 const V16 b0 =
LoadU(d16, row + i);
378 const V16 v0 =
LoadU(d16, vec + i);
381 const size_t remainder = kInner - i;
382 if (remainder != 0) {
383 const V16 b0 =
LoadN(d16, row + i, remainder);
384 const V16 v0 =
LoadN(d16, vec + i, remainder);
388 sum0 =
Add(sum0, sum1);
389 sum2 =
Add(sum2, sum3);
390 sum0 =
Add(sum0, sum2);
393 buf[idx_row] =
AddScalar(buf[idx_row], add[begin + idx_row]);
397 for (
size_t i = 0; i != kChunkSize; i += N / 2) {
398 Stream(
Load(df, buf + i), df, out + begin + i);
404 for (
size_t r = num_chunks * kChunkSize; r < kOuter; ++r) {
405 auto sum0 =
Zero(df);
406 auto sum1 =
Zero(df);
411 for (; i + N <= kInner; i += N) {
412 const V16 b0 =
LoadU(d16, row + i);
413 const V16 v0 =
LoadU(d16, vec + i);
416 const size_t remainder = kInner - i;
417 if (remainder != 0) {
418 const V16 b0 =
LoadN(d16, row + i, remainder);
419 const V16 v0 =
LoadN(d16, vec + i, remainder);
427template <
size_t kOuter,
size_t kInner>
432 MatVecAddImpl<kOuter, kInner, true>(mat, vec, add, out, pool);
435template <
size_t kOuter,
size_t kInner>
439 MatVecAddImpl<kOuter, kInner, false>(mat, vec,
nullptr, out, pool);
#define HWY_RESTRICT
Definition base.h:95
#define HWY_NOINLINE
Definition base.h:103
#define HWY_IF_CONSTEXPR
Definition base.h:310
#define HWY_DASSERT(condition)
Definition base.h:290
#define HWY_UNROLL(factor)
Definition base.h:187
Definition thread_pool.h:525
void Run(uint64_t begin, uint64_t end, const Closure &closure)
Definition thread_pool.h:627
HWY_NOINLINE void MatVecAdd(const T *HWY_RESTRICT mat, const T *HWY_RESTRICT vec, const T *HWY_RESTRICT add, T *HWY_RESTRICT out, hwy::ThreadPool &pool)
Definition matvec-inl.h:159
D d
Definition arm_sve-inl.h:1915
HWY_NOINLINE void MatVec(const T *HWY_RESTRICT mat, const T *HWY_RESTRICT vec, T *HWY_RESTRICT out, hwy::ThreadPool &pool)
Definition matvec-inl.h:175
HWY_API VFromD< D32 > ReorderWidenMulAccumulate(D32 df32, V16 a, V16 b, const VFromD< D32 > sum0, VFromD< D32 > &sum1)
Definition arm_neon-inl.h:6571
HWY_API Vec128< T, N > MulAdd(Vec128< T, N > mul, Vec128< T, N > x, Vec128< T, N > add)
Definition arm_neon-inl.h:2550
HWY_API Vec128< uint8_t > LoadU(D, const uint8_t *HWY_RESTRICT unaligned)
Definition arm_neon-inl.h:3442
HWY_API VFromD< D > Zero(D d)
Definition arm_neon-inl.h:947
HWY_NOINLINE void MatVecAddImpl(const T *HWY_RESTRICT mat, const T *HWY_RESTRICT vec, const T *HWY_RESTRICT add, T *HWY_RESTRICT out, hwy::ThreadPool &pool)
Definition matvec-inl.h:40
HWY_API VFromD< D > Load(D d, const TFromD< D > *HWY_RESTRICT p)
Definition arm_neon-inl.h:3664
HWY_API V Add(V a, V b)
Definition generic_ops-inl.h:7300
HWY_API VFromD< D > PromoteLowerTo(D d, V v)
Definition generic_ops-inl.h:2984
HWY_API Vec128< uint16_t > PromoteTo(D, Vec64< uint8_t > v)
Definition arm_neon-inl.h:4252
typename detail::ScalableTagChecker< T, kPow2 >::type ScalableTag
Definition ops/shared-inl.h:367
typename D::Half Half
Definition ops/shared-inl.h:487
TA AddScalar(TA a, TB b)
Definition matvec-inl.h:34
HWY_API void Stream(const VFromD< D > v, D d, TFromD< D > *HWY_RESTRICT aligned)
Definition arm_neon-inl.h:3932
HWY_API VFromD< D > LoadN(D d, const TFromD< D > *HWY_RESTRICT p, size_t max_lanes_to_load)
Definition emu128-inl.h:1352
HWY_API VFromD< D > PromoteUpperTo(D d, V v)
Definition arm_sve-inl.h:2228
decltype(Zero(D())) Vec
Definition generic_ops-inl.h:46
HWY_API size_t Lanes(D)
Definition rvv-inl.h:598
HWY_API TFromD< D > ReduceSum(D, VFromD< D > v)
Definition arm_neon-inl.h:8027
typename D::template Repartition< T > Repartition
Definition ops/shared-inl.h:471
HWY_INLINE HWY_ATTR_CACHE void FlushStream()
Definition cache_control.h:73
#define HWY_ALIGN
Definition set_macros-inl.h:167
#define HWY_NAMESPACE
Definition set_macros-inl.h:166
#define HWY_ATTR
Definition set_macros-inl.h:646