|
|
|
@ -186,9 +186,9 @@ template <typename DeviceContext, typename T>
|
|
|
|
|
class TileGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
|
|
|
auto* in0 = context.Input<Tensor>("X");
|
|
|
|
|
auto* x = context.Input<Tensor>("X");
|
|
|
|
|
auto repeat_times = get_repeat_times(context);
|
|
|
|
|
auto x_dims = in0->dims();
|
|
|
|
|
auto x_dims = x->dims();
|
|
|
|
|
auto vec_in_dims = framework::vectorize<int>(x_dims);
|
|
|
|
|
if (repeat_times.size() < vec_in_dims.size()) {
|
|
|
|
|
int diff = vec_in_dims.size() - repeat_times.size();
|
|
|
|
@ -220,11 +220,13 @@ class TileGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
}
|
|
|
|
|
// no need to reduce, just copy
|
|
|
|
|
if (just_copy) {
|
|
|
|
|
auto* in0 = context.Input<Tensor>(framework::GradVarName("Out"));
|
|
|
|
|
auto* out0 = context.Output<Tensor>(framework::GradVarName("X"));
|
|
|
|
|
out0->mutable_data<T>(context.GetPlace());
|
|
|
|
|
framework::TensorCopy(*in0, context.GetPlace(), context.device_context(),
|
|
|
|
|
out0);
|
|
|
|
|
auto* dout = context.Input<Tensor>(framework::GradVarName("Out"));
|
|
|
|
|
auto* dx = context.Output<Tensor>(framework::GradVarName("X"));
|
|
|
|
|
dx->mutable_data<T>(context.GetPlace());
|
|
|
|
|
framework::TensorCopy(*dout, context.GetPlace(), context.device_context(),
|
|
|
|
|
dx);
|
|
|
|
|
// TensorCopy may change the dims of dx
|
|
|
|
|
dx->Resize(x_dims);
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_ENFORCE_GE(dims, 1,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
@ -261,6 +263,7 @@ class TileGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
for (size_t i = 0; i < reduce_size; ++i) {
|
|
|
|
|
reduce_dims[i] = reduce_dims_vec[i];
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto out_grad = EigenVector<T>::Flatten(*in0);
|
|
|
|
|
x_grad.device(
|
|
|
|
|
*context.template device_context<DeviceContext>().eigen_device()) =
|
|
|
|
|