Skip to content

Commit d636a83

Browse files
author
Jerry Hou
committed
online regression refactor
1 parent 49fb4e6 commit d636a83

File tree

2 files changed

+19
-18
lines changed

2 files changed

+19
-18
lines changed

nvbench/detail/entropy_criterion.cxx

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,8 @@ void entropy_criterion::do_add_measurement(nvbench::float64_t measurement)
104104
}
105105
}
106106

107-
update_entropy_sum(old_count, old_count + 1);
107+
update_entropy_sum(static_cast<nvbench::float64_t>(old_count),
108+
static_cast<nvbench::float64_t>(old_count + 1));
108109
const nvbench::float64_t entropy = compute_entropy();
109110
const nvbench::float64_t n = static_cast<nvbench::float64_t>(m_entropy_tracker.size());
110111

nvbench/detail/online_linear_regression.cuh

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -104,22 +104,21 @@ public:
104104

105105
[[nodiscard]] nvbench::float64_t slope() const
106106
{
107+
static constexpr nvbench::float_64_t q_nan =
108+
std::numeric_limits<nvbench::float64_t>::quiet_NaN();
109+
107110
if (m_count < 2)
108-
{
109-
return std::numeric_limits<nvbench::float64_t>::quiet_NaN();
110-
}
111+
return q_nan;
111112

112113
const nvbench::float64_t n = static_cast<nvbench::float64_t>(m_count);
113-
const nvbench::float64_t mean_x = this->mean_x();
114-
const nvbench::float64_t mean_y = this->mean_y();
114+
const nvbench::float64_t mean_x = (m_sum_x / n);
115+
const nvbench::float64_t mean_y = (m_sum_y / n);
115116

116-
const nvbench::float64_t numerator = m_sum_xy - n * mean_x * mean_y;
117-
const nvbench::float64_t denominator = m_sum_x2 - n * mean_x * mean_x;
117+
const nvbench::float64_t numerator = (m_sum_xy / n) - mean_x * mean_y;
118+
const nvbench::float64_t denominator = (m_sum_x2 / n) - mean_x * mean_x;
118119

119-
if (std::abs(denominator) < 1e-9)
120-
{
121-
return std::numeric_limits<nvbench::float64_t>::quiet_NaN();
122-
}
120+
if (std::abs(denominator) < 1e-12)
121+
return q_nan;
123122

124123
return numerator / denominator;
125124
}
@@ -148,11 +147,12 @@ public:
148147
return std::numeric_limits<nvbench::float64_t>::quiet_NaN();
149148
}
150149

150+
// ss_tot and ss_res scaled by 1/n to avoid overflow
151151
const nvbench::float64_t n = static_cast<nvbench::float64_t>(m_count);
152152
const nvbench::float64_t mean_y_v = mean_y();
153-
const nvbench::float64_t ss_tot = m_sum_y2 - n * mean_y_v * mean_y_v;
153+
const nvbench::float64_t ss_tot = (m_sum_y2 / n) - mean_y_v * mean_y_v;
154154

155-
if (ss_tot < 1e-9)
155+
if (ss_tot == 0)
156156
{
157157
return 1.0;
158158
}
@@ -166,10 +166,10 @@ public:
166166
}
167167
else
168168
{
169-
const nvbench::float64_t ss_res = m_sum_y2 - 2.0 * slope_v * m_sum_xy -
170-
2.0 * intercept_v * m_sum_y + slope_v * slope_v * m_sum_x2 +
171-
2.0 * slope_v * intercept_v * m_sum_x +
172-
n * intercept_v * intercept_v;
169+
const nvbench::float64_t ss_res =
170+
(m_sum_y2 / n) - 2.0 * slope_v * (m_sum_xy / n) - 2.0 * intercept_v * (m_sum_y / n) +
171+
slope_v * slope_v * (m_sum_x2 / n) + 2.0 * slope_v * intercept_v * (m_sum_x / n) +
172+
intercept_v * intercept_v;
173173

174174
return std::max(0.0, std::min(1.0, 1.0 - (ss_res / ss_tot)));
175175
}

0 commit comments

Comments
 (0)