Grok 12.0.1
robust_statistics.h
Go to the documentation of this file.
1// Copyright 2023 Google LLC
2// SPDX-License-Identifier: Apache-2.0
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16#ifndef HIGHWAY_HWY_ROBUST_STATISTICS_H_
17#define HIGHWAY_HWY_ROBUST_STATISTICS_H_
18
19#include <algorithm> // std::sort, std::find_if
20#include <limits>
21#include <utility> // std::pair
22#include <vector>
23
24#include "hwy/base.h"
25
26namespace hwy {
27namespace robust_statistics {
28
29// Sorts integral values in ascending order (e.g. for Mode). About 3x faster
30// than std::sort for input distributions with very few unique values.
31template <class T>
32void CountingSort(T* values, size_t num_values) {
33 // Unique values and their frequency (similar to flat_map).
34 using Unique = std::pair<T, int>;
35 std::vector<Unique> unique;
36 for (size_t i = 0; i < num_values; ++i) {
37 const T value = values[i];
38 const auto pos =
39 std::find_if(unique.begin(), unique.end(),
40 [value](const Unique u) { return u.first == value; });
41 if (pos == unique.end()) {
42 unique.push_back(std::make_pair(value, 1));
43 } else {
44 ++pos->second;
45 }
46 }
47
48 // Sort in ascending order of value (pair.first).
49 std::sort(unique.begin(), unique.end());
50
51 // Write that many copies of each unique value to the array.
52 T* HWY_RESTRICT p = values;
53 for (const auto& value_count : unique) {
54 std::fill(p, p + value_count.second, value_count.first);
55 p += value_count.second;
56 }
57 HWY_ASSERT(p == values + num_values);
58}
59
60// @return i in [idx_begin, idx_begin + half_count) that minimizes
61// sorted[i + half_count] - sorted[i].
62template <typename T>
63size_t MinRange(const T* const HWY_RESTRICT sorted, const size_t idx_begin,
64 const size_t half_count) {
65 T min_range = std::numeric_limits<T>::max();
66 size_t min_idx = 0;
67
68 for (size_t idx = idx_begin; idx < idx_begin + half_count; ++idx) {
69 HWY_ASSERT(sorted[idx] <= sorted[idx + half_count]);
70 const T range = sorted[idx + half_count] - sorted[idx];
71 if (range < min_range) {
72 min_range = range;
73 min_idx = idx;
74 }
75 }
76
77 return min_idx;
78}
79
80// Returns an estimate of the mode by calling MinRange on successively
81// halved intervals. "sorted" must be in ascending order. This is the
82// Half Sample Mode estimator proposed by Bickel in "On a fast, robust
83// estimator of the mode", with complexity O(N log N). The mode is less
84// affected by outliers in highly-skewed distributions than the median.
85// The averaging operation below assumes "T" is an unsigned integer type.
86template <typename T>
87T ModeOfSorted(const T* const HWY_RESTRICT sorted, const size_t num_values) {
88 size_t idx_begin = 0;
89 size_t half_count = num_values / 2;
90 while (half_count > 1) {
91 idx_begin = MinRange(sorted, idx_begin, half_count);
92 half_count >>= 1;
93 }
94
95 const T x = sorted[idx_begin + 0];
96 if (half_count == 0) {
97 return x;
98 }
99 HWY_ASSERT(half_count == 1);
100 const T average = (x + sorted[idx_begin + 1] + 1) / 2;
101 return average;
102}
103
// Sorts "values" in place with CountingSort, then returns the Half Sample
// Mode estimate that ModeOfSorted computes over all num_values elements.
// Side effect: "values" is left in ascending order.
template <typename T>
T Mode(T* values, const size_t num_values) {
  CountingSort(values, num_values);
  const T* const sorted = values;
  return ModeOfSorted(sorted, num_values);
}
110
// Convenience overload for C arrays: forwards to Mode(pointer, count) using
// the array's compile-time length N. Side effect: sorts the array in place.
template <typename T, size_t N>
T Mode(T (&values)[N]) {
  T* const first = values;
  return Mode(first, N);
}
115
116// Returns the median value. Side effect: sorts "values".
117template <typename T>
118T Median(T* values, const size_t num_values) {
119 HWY_ASSERT(num_values != 0);
120 std::sort(values, values + num_values);
121 const size_t half = num_values / 2;
122 // Odd count: return middle
123 if (num_values % 2) {
124 return values[half];
125 }
126 // Even count: return average of middle two.
127 return (values[half] + values[half - 1] + 1) / 2;
128}
129
130// Returns a robust measure of variability.
131template <typename T>
132T MedianAbsoluteDeviation(const T* values, const size_t num_values,
133 const T median) {
134 HWY_ASSERT(num_values != 0);
135 std::vector<T> abs_deviations;
136 abs_deviations.reserve(num_values);
137 for (size_t i = 0; i < num_values; ++i) {
138 const int64_t abs = ScalarAbs(static_cast<int64_t>(values[i]) -
139 static_cast<int64_t>(median));
140 abs_deviations.push_back(static_cast<T>(abs));
141 }
142 return Median(abs_deviations.data(), num_values);
143}
144
145} // namespace robust_statistics
146} // namespace hwy
147
148#endif // HIGHWAY_HWY_ROBUST_STATISTICS_H_
#define HWY_RESTRICT
Definition base.h:95
#define HWY_ASSERT(condition)
Definition base.h:237
T MedianAbsoluteDeviation(const T *values, const size_t num_values, const T median)
Definition robust_statistics.h:132
T Median(T *values, const size_t num_values)
Definition robust_statistics.h:118
void CountingSort(T *values, size_t num_values)
Definition robust_statistics.h:32
T Mode(T *values, const size_t num_values)
Definition robust_statistics.h:106
T ModeOfSorted(const T *const HWY_RESTRICT sorted, const size_t num_values)
Definition robust_statistics.h:87
size_t MinRange(const T *const HWY_RESTRICT sorted, const size_t idx_begin, const size_t half_count)
Definition robust_statistics.h:63
Definition abort.h:8
HWY_API HWY_BITCASTSCALAR_CONSTEXPR RemoveCvRef< T > ScalarAbs(T val)
Definition base.h:2815