|
|
@ -30,7 +30,7 @@ struct StackGradFunctor {
|
|
|
|
int i = idx / (n_ * post_);
|
|
|
|
int i = idx / (n_ * post_);
|
|
|
|
int which_x = idx / post_ - i * n_;
|
|
|
|
int which_x = idx / post_ - i * n_;
|
|
|
|
int x_index = i * post_ + idx % post_;
|
|
|
|
int x_index = i * post_ + idx % post_;
|
|
|
|
dx_[which_x][x_index] = dy_[idx];
|
|
|
|
if (dx_[which_x] != nullptr) dx_[which_x][x_index] = dy_[idx];
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
private:
|
|
|
@ -95,19 +95,21 @@ class StackGradKernel : public framework::OpKernel<T> {
|
|
|
|
auto dx = ctx.MultiOutput<Tensor>(framework::GradVarName("X"));
|
|
|
|
auto dx = ctx.MultiOutput<Tensor>(framework::GradVarName("X"));
|
|
|
|
int axis = ctx.Attr<int>("axis");
|
|
|
|
int axis = ctx.Attr<int>("axis");
|
|
|
|
if (axis < 0) axis += dy->dims().size();
|
|
|
|
if (axis < 0) axis += dy->dims().size();
|
|
|
|
|
|
|
|
|
|
|
|
int n = dy->dims()[axis];
|
|
|
|
int n = dy->dims()[axis];
|
|
|
|
std::vector<T *> dx_datas(n); // NOLINT
|
|
|
|
std::vector<T *> dx_datas(n); // NOLINT
|
|
|
|
|
|
|
|
|
|
|
|
for (int i = 0; i < n; i++) {
|
|
|
|
for (int i = 0; i < n; i++) {
|
|
|
|
dx_datas[i] = dx[i]->mutable_data<T>(ctx.GetPlace());
|
|
|
|
if (dx[i] == nullptr) {
|
|
|
|
|
|
|
|
dx_datas[i] = nullptr;
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
dx_datas[i] = dx[i]->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
auto dy_data = dy->data<T>();
|
|
|
|
auto dy_data = dy->data<T>();
|
|
|
|
|
|
|
|
|
|
|
|
int pre = 1;
|
|
|
|
int pre = 1;
|
|
|
|
for (int i = 0; i < axis; ++i) pre *= dy->dims()[i];
|
|
|
|
for (int i = 0; i < axis; ++i) pre *= dy->dims()[i];
|
|
|
|
int total_num = dy->numel();
|
|
|
|
int total_num = dy->numel();
|
|
|
|
int post = total_num / (n * pre);
|
|
|
|
int post = total_num / (n * pre);
|
|
|
|
|
|
|
|
|
|
|
|
auto &dev_ctx = ctx.template device_context<DeviceContext>();
|
|
|
|
auto &dev_ctx = ctx.template device_context<DeviceContext>();
|
|
|
|
auto dx_data_arr = dx_datas.data();
|
|
|
|
auto dx_data_arr = dx_datas.data();
|
|
|
|
StackGradFunctorForRange(dev_ctx, dx_data_arr, dy_data, total_num, n, post);
|
|
|
|
StackGradFunctorForRange(dev_ctx, dx_data_arr, dy_data, total_num, n, post);
|
|
|
|