|
|
|
@ -147,11 +147,14 @@ struct KronGradElemFunctor {
|
|
|
|
|
index_b += stride_b_[i] * pos_bi;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
size_t index_out_a = index_a * numel_b_ + index_b;
|
|
|
|
|
size_t index_out_b = index_b * numel_a_ + index_a;
|
|
|
|
|
|
|
|
|
|
dout_a_[index_out_a] = dout_[idx] * B_[index_b];
|
|
|
|
|
dout_b_[index_out_b] = dout_[idx] * A_[index_a];
|
|
|
|
|
if (dout_a_) {
|
|
|
|
|
size_t index_out_a = index_a * numel_b_ + index_b;
|
|
|
|
|
dout_a_[index_out_a] = dout_[idx] * B_[index_b];
|
|
|
|
|
}
|
|
|
|
|
if (dout_b_) {
|
|
|
|
|
size_t index_out_b = index_b * numel_a_ + index_a;
|
|
|
|
|
dout_b_[index_out_b] = dout_[idx] * A_[index_a];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
@ -222,35 +225,50 @@ struct KronGradOpFunctor {
|
|
|
|
|
// dout_x: dout * kron(ones(X), Y) re-arranged in shape (numel_x, numel_y)
|
|
|
|
|
// dout_y: dout * kron(X, ones(Y)) re-arranged in shape (numel_y, numel_x)
|
|
|
|
|
framework::Tensor dout_x;
|
|
|
|
|
dout_x.mutable_data<T>({numel_x, numel_y}, dev_ctx.GetPlace());
|
|
|
|
|
T* p_dout_x = nullptr;
|
|
|
|
|
if (dx) {
|
|
|
|
|
dout_x.mutable_data<T>({numel_x, numel_y}, dev_ctx.GetPlace());
|
|
|
|
|
p_dout_x = dout_x.data<T>();
|
|
|
|
|
}
|
|
|
|
|
framework::Tensor dout_y;
|
|
|
|
|
dout_y.mutable_data<T>({numel_y, numel_x}, dev_ctx.GetPlace());
|
|
|
|
|
T* p_dout_y = nullptr;
|
|
|
|
|
if (dy) {
|
|
|
|
|
dout_y.mutable_data<T>({numel_y, numel_x}, dev_ctx.GetPlace());
|
|
|
|
|
p_dout_y = dout_y.data<T>();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
|
|
|
|
|
KronGradElemFunctor<T> func(dout.data<T>(), x.data<T>(), y.data<T>(),
|
|
|
|
|
dout_x.data<T>(), dout_y.data<T>(),
|
|
|
|
|
p_stride_dout, p_stride_x, p_stride_y,
|
|
|
|
|
p_shape_y, numel_x, numel_y, ndims);
|
|
|
|
|
p_dout_x, p_dout_y, p_stride_dout, p_stride_x,
|
|
|
|
|
p_stride_y, p_shape_y, numel_x, numel_y, ndims);
|
|
|
|
|
for_range(func);
|
|
|
|
|
|
|
|
|
|
// reduce_sum along axis 1
|
|
|
|
|
#if __NVCC__
|
|
|
|
|
auto stream = dev_ctx.stream(); // it is a cuda device_context
|
|
|
|
|
TensorReduce<T, T, cub::Sum, IdentityFunctor<T>>(
|
|
|
|
|
dout_x, dx, {1}, static_cast<T>(0), cub::Sum(), IdentityFunctor<T>(),
|
|
|
|
|
stream);
|
|
|
|
|
TensorReduce<T, T, cub::Sum, IdentityFunctor<T>>(
|
|
|
|
|
dout_y, dy, {1}, static_cast<T>(0), cub::Sum(), IdentityFunctor<T>(),
|
|
|
|
|
stream);
|
|
|
|
|
if (dx) {
|
|
|
|
|
TensorReduce<T, T, cub::Sum, IdentityFunctor<T>>(
|
|
|
|
|
dout_x, dx, {1}, static_cast<T>(0), cub::Sum(), IdentityFunctor<T>(),
|
|
|
|
|
stream);
|
|
|
|
|
}
|
|
|
|
|
if (dy) {
|
|
|
|
|
TensorReduce<T, T, cub::Sum, IdentityFunctor<T>>(
|
|
|
|
|
dout_y, dy, {1}, static_cast<T>(0), cub::Sum(), IdentityFunctor<T>(),
|
|
|
|
|
stream);
|
|
|
|
|
}
|
|
|
|
|
#else
|
|
|
|
|
auto eigen_dout_x = framework::EigenMatrix<T>::Reshape(dout_x, 1);
|
|
|
|
|
auto eigen_dout_y = framework::EigenMatrix<T>::Reshape(dout_y, 1);
|
|
|
|
|
auto eigen_vec_dx = framework::EigenVector<T>::Flatten(*dx);
|
|
|
|
|
auto eigen_vec_dy = framework::EigenVector<T>::Flatten(*dy);
|
|
|
|
|
auto* place = dev_ctx.eigen_device();
|
|
|
|
|
Eigen::array<int, 1> reduce_dim = {1};
|
|
|
|
|
eigen_vec_dx.device(*place) = eigen_dout_x.sum(reduce_dim);
|
|
|
|
|
eigen_vec_dy.device(*place) = eigen_dout_y.sum(reduce_dim);
|
|
|
|
|
if (dx) {
|
|
|
|
|
auto eigen_dout_x = framework::EigenMatrix<T>::Reshape(dout_x, 1);
|
|
|
|
|
auto eigen_vec_dx = framework::EigenVector<T>::Flatten(*dx);
|
|
|
|
|
eigen_vec_dx.device(*place) = eigen_dout_x.sum(reduce_dim);
|
|
|
|
|
}
|
|
|
|
|
if (dy) {
|
|
|
|
|
auto eigen_dout_y = framework::EigenMatrix<T>::Reshape(dout_y, 1);
|
|
|
|
|
auto eigen_vec_dy = framework::EigenVector<T>::Flatten(*dy);
|
|
|
|
|
eigen_vec_dy.device(*place) = eigen_dout_y.sum(reduce_dim);
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
@ -307,17 +325,33 @@ class KronGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
|
|
|
|
|
auto* dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
|
|
|
|
|
auto* dy = ctx.Output<framework::Tensor>(framework::GradVarName("Y"));
|
|
|
|
|
dx->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
dy->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
if (dx) {
|
|
|
|
|
dx->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
}
|
|
|
|
|
if (dy) {
|
|
|
|
|
dy->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int ndims = dout->dims().size();
|
|
|
|
|
framework::Tensor xx = UnsqueezeTo(*x, ndims);
|
|
|
|
|
framework::Tensor dxx = UnsqueezeTo(*dx, ndims);
|
|
|
|
|
framework::Tensor yy = UnsqueezeTo(*y, ndims);
|
|
|
|
|
framework::Tensor dyy = UnsqueezeTo(*dy, ndims);
|
|
|
|
|
|
|
|
|
|
framework::Tensor* pdxx = nullptr;
|
|
|
|
|
framework::Tensor* pdyy = nullptr;
|
|
|
|
|
framework::Tensor dxx;
|
|
|
|
|
framework::Tensor dyy;
|
|
|
|
|
if (dx) {
|
|
|
|
|
dxx = UnsqueezeTo(*dx, ndims);
|
|
|
|
|
pdxx = &dxx;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (dy) {
|
|
|
|
|
dyy = UnsqueezeTo(*dy, ndims);
|
|
|
|
|
pdyy = &dyy;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
KronGradOpFunctor<DeviceContext, T> func;
|
|
|
|
|
func(dev_ctx, *dout, xx, yy, &dxx, &dyy);
|
|
|
|
|
func(dev_ctx, *dout, xx, yy, pdxx, pdyy);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|