|
|
|
@ -72,13 +72,13 @@ class XeGradFunctor {
|
|
|
|
|
size_t num_classes)
|
|
|
|
|
: dx_(dx), dy_(dy), x_(x), label_(label), num_classes_(num_classes) {}
|
|
|
|
|
|
|
|
|
|
HOSTDEVICE void operator()(size_t label_id) {
|
|
|
|
|
auto x_is_true_offset = label_id * num_classes_ + label_[label_id];
|
|
|
|
|
for (size_t x_offset = label_id * num_classes_;
|
|
|
|
|
x_offset < (label_id + 1) * num_classes_; ++x_offset) {
|
|
|
|
|
HOSTDEVICE void operator()(size_t sample_id) {
|
|
|
|
|
auto x_is_true_offset = sample_id * num_classes_ + label_[sample_id];
|
|
|
|
|
for (size_t x_offset = sample_id * num_classes_;
|
|
|
|
|
x_offset < (sample_id + 1) * num_classes_; ++x_offset) {
|
|
|
|
|
dx_[x_offset] = x_offset != x_is_true_offset
|
|
|
|
|
? static_cast<T>(0)
|
|
|
|
|
: -dy_[label_id] / x_[x_offset];
|
|
|
|
|
: -dy_[sample_id] / x_[x_offset];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|