|
|
|
@ -63,7 +63,7 @@ void MaximumGradRecTask(T *x, T *y, T *dout, T *dx, T *dy, size_t dim, size_t x_
|
|
|
|
|
size_t dout_i = i * dout_cargo[dim];
|
|
|
|
|
|
|
|
|
|
if (dim == dout_shape.size() - 1) {
|
|
|
|
|
if (*(x + x_index + x_i) >= *(y + y_index + y_i)) {
|
|
|
|
|
if (*(x + x_index + x_i) > *(y + y_index + y_i)) {
|
|
|
|
|
*(dx + x_index + x_i) += *(dout + dout_index + i);
|
|
|
|
|
} else {
|
|
|
|
|
*(dy + y_index + y_i) += *(dout + dout_index + i);
|
|
|
|
|