Skip to content

Commit f5f13f9

Browse files
author
Jerry Hou
committed
revising ss_tot handling
1 parent d636a83 commit f5f13f9

File tree

1 file changed

+11
-8
lines changed

1 file changed

+11
-8
lines changed

nvbench/detail/online_linear_regression.cuh

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ public:
104104

105105
[[nodiscard]] nvbench::float64_t slope() const
106106
{
107-
static constexpr nvbench::float_64_t q_nan =
107+
static constexpr nvbench::float64_t q_nan =
108108
std::numeric_limits<nvbench::float64_t>::quiet_NaN();
109109

110110
if (m_count < 2)
@@ -152,7 +152,7 @@ public:
152152
const nvbench::float64_t mean_y_v = mean_y();
153153
const nvbench::float64_t ss_tot = (m_sum_y2 / n) - mean_y_v * mean_y_v;
154154

155-
if (ss_tot == 0)
155+
if (ss_tot < std::numeric_limits<nvbench::float64_t>::epsilon())
156156
{
157157
return 1.0;
158158
}
@@ -166,12 +166,15 @@ public:
166166
}
167167
else
168168
{
169-
const nvbench::float64_t ss_res =
170-
(m_sum_y2 / n) - 2.0 * slope_v * (m_sum_xy / n) - 2.0 * intercept_v * (m_sum_y / n) +
171-
slope_v * slope_v * (m_sum_x2 / n) + 2.0 * slope_v * intercept_v * (m_sum_x / n) +
172-
intercept_v * intercept_v;
173-
174-
return std::max(0.0, std::min(1.0, 1.0 - (ss_res / ss_tot)));
169+
const nvbench::float64_t mean_xy_v = m_sum_xy / n;
170+
const nvbench::float64_t mean_xx_v = m_sum_x2 / n;
171+
const nvbench::float64_t mean_x_v = m_sum_x / n;
172+
const nvbench::float64_t ss_tot_m_res =
173+
slope_v * ((mean_xy_v - slope_v * mean_xx_v) + (mean_xy_v - intercept_v * mean_x_v)) +
174+
intercept_v * (mean_y_v - slope_v * mean_x_v - intercept_v) +
175+
mean_y_v * (intercept_v - mean_y_v);
176+
177+
return std::min(std::max(ss_tot_m_res / ss_tot, 0.0), 1.0);
175178
}
176179
}
177180

0 commit comments

Comments
 (0)