|
|
|
@ -86,10 +86,10 @@ class XeGradFunctor {
|
|
|
|
|
auto x_is_true_offset = sample_id * num_classes_ + label_[sample_id];
|
|
|
|
|
for (size_t x_offset = sample_id * num_classes_;
|
|
|
|
|
x_offset < (sample_id + 1) * num_classes_; ++x_offset) {
|
|
|
|
|
dx_[x_offset] =
|
|
|
|
|
(x_offset != x_is_true_offset || label_[sample_id] == ignore_index_)
|
|
|
|
|
? static_cast<T>(0)
|
|
|
|
|
: -dy_[sample_id] / x_[x_offset];
|
|
|
|
|
dx_[x_offset] = (x_offset != x_is_true_offset ||
|
|
|
|
|
label_[sample_id] == static_cast<int64_t>(ignore_index_))
|
|
|
|
|
? static_cast<T>(0)
|
|
|
|
|
: -dy_[sample_id] / x_[x_offset];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|