|
|
@ -32,9 +32,9 @@ struct HuberLossForward {
|
|
|
|
HOSTDEVICE T operator()(const T& val) const {
|
|
|
|
HOSTDEVICE T operator()(const T& val) const {
|
|
|
|
T abs_val = std::abs(val);
|
|
|
|
T abs_val = std::abs(val);
|
|
|
|
if (abs_val <= delta) {
|
|
|
|
if (abs_val <= delta) {
|
|
|
|
return 0.5 * val * val;
|
|
|
|
return static_cast<T>(0.5) * val * val;
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
return delta * (abs_val - 0.5 * delta);
|
|
|
|
return delta * (abs_val - static_cast<T>(0.5) * delta);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|