You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Paddle/paddle/fluid/operators/kron_op.h

360 lines
12 KiB

/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/for_range.h"
#if __NVCC__
#include "paddle/fluid/operators/reduce_ops/cub_reduce.h"
#include "thrust/device_vector.h"
#endif
namespace paddle {
namespace operators {
// Process an element in the output, used with a parallel-for
template <typename T>
struct KronElemFunctor {
KronElemFunctor(const T* a, const T* b, T* out, const int64_t* shape_b,
const int64_t* stride_a, const int64_t* stride_b,
const int64_t* stride_out, int ndims)
: a_(a),
b_(b),
out_(out),
shape_b_(shape_b),
stride_a_(stride_a),
stride_b_(stride_b),
stride_out_(stride_out),
ndims_(ndims) {}
HOSTDEVICE void operator()(int64_t idx) const {
// it computes 1 element in the output
int64_t index = idx;
int64_t index_a = 0;
int64_t index_b = 0;
for (int i = 0; i < ndims_; i++) {
auto pos_i = index / stride_out_[i];
index = index % stride_out_[i];
auto pos_ai = pos_i / shape_b_[i];
auto pos_bi = pos_i % shape_b_[i];
index_a += stride_a_[i] * pos_ai;
index_b += stride_b_[i] * pos_bi;
}
out_[idx] = a_[index_a] * b_[index_b];
}
private:
const T* a_;
const T* b_;
T* out_;
const int64_t* shape_b_;
const int64_t* stride_a_;
const int64_t* stride_b_;
const int64_t* stride_out_;
const int ndims_;
};
template <typename DeviceContext, typename T>
struct KronOpFunctor {
void operator()(const DeviceContext& dev_ctx, const framework::Tensor& x,
const framework::Tensor& y, framework::Tensor* out) {
int ndims = out->dims().size();
int64_t numel = out->numel();
const framework::DDim& dim_x = x.dims();
const framework::DDim& dim_y = y.dims();
const framework::DDim& dim_out = out->dims();
const framework::DDim stride_x = framework::stride(dim_x);
const framework::DDim stride_y = framework::stride(dim_y);
const framework::DDim stride_out = framework::stride(dim_out);
const int64_t *p_stride_x = nullptr, *p_stride_y = nullptr,
*p_stride_out = nullptr, *p_shape_y = nullptr;
#if __NVCC__
thrust::device_vector<int64_t> d_stride_x(ndims);
thrust::device_vector<int64_t> d_stride_y(ndims);
thrust::device_vector<int64_t> d_stride_out(ndims);
thrust::device_vector<int64_t> d_shape_y(ndims);
thrust::copy(stride_x.Get(), stride_x.Get() + ndims, d_stride_x.begin());
thrust::copy(stride_y.Get(), stride_y.Get() + ndims, d_stride_y.begin());
thrust::copy(stride_out.Get(), stride_out.Get() + ndims,
d_stride_out.begin());
thrust::copy(dim_y.Get(), dim_y.Get() + ndims, d_shape_y.begin());
p_stride_x = thrust::raw_pointer_cast(d_stride_x.data());
p_stride_y = thrust::raw_pointer_cast(d_stride_y.data());
p_stride_out = thrust::raw_pointer_cast(d_stride_out.data());
p_shape_y = thrust::raw_pointer_cast(d_shape_y.data());
#else
p_stride_x = stride_x.Get();
p_stride_y = stride_y.Get();
p_stride_out = stride_out.Get();
p_shape_y = dim_y.Get();
#endif
platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
KronElemFunctor<T> functor(x.data<T>(), y.data<T>(), out->data<T>(),
p_shape_y, p_stride_x, p_stride_y, p_stride_out,
ndims);
for_range(functor);
}
};
template <typename T>
struct KronGradElemFunctor {
KronGradElemFunctor(const T* dout, const T* A, const T* B, T* dout_a,
T* dout_b, const int64_t* stride_dout,
const int64_t* stride_a, const int64_t* stride_b,
const int64_t* shape_b, const int64_t numel_a,
const int64_t numel_b, const int ndims)
: dout_(dout),
A_(A),
B_(B),
dout_a_(dout_a),
dout_b_(dout_b),
stride_dout_(stride_dout),
stride_a_(stride_a),
stride_b_(stride_b),
shape_b_(shape_b),
numel_a_(numel_a),
numel_b_(numel_b),
ndims_(ndims) {}
HOSTDEVICE void operator()(int64_t idx) {
int64_t index = idx;
int64_t index_a = 0;
int64_t index_b = 0;
for (int i = 0; i < ndims_; i++) {
auto pos_i = index / stride_dout_[i];
index = index % stride_dout_[i];
auto pos_ai = pos_i / shape_b_[i];
auto pos_bi = pos_i % shape_b_[i];
index_a += stride_a_[i] * pos_ai;
index_b += stride_b_[i] * pos_bi;
}
if (dout_a_) {
size_t index_out_a = index_a * numel_b_ + index_b;
dout_a_[index_out_a] = dout_[idx] * B_[index_b];
}
if (dout_b_) {
size_t index_out_b = index_b * numel_a_ + index_a;
dout_b_[index_out_b] = dout_[idx] * A_[index_a];
}
}
private:
const T* dout_;
const T* A_;
const T* B_;
T* dout_a_;
T* dout_b_;
const int64_t* stride_dout_;
const int64_t* stride_a_;
const int64_t* stride_b_;
const int64_t* shape_b_;
const int64_t numel_a_;
const int64_t numel_b_;
const int ndims_;
};
template <typename T>
struct IdentityFunctor {
HOSTDEVICE explicit inline IdentityFunctor() {}
HOSTDEVICE inline T operator()(const T& x) const { return x; }
};
template <typename DeviceContext, typename T>
struct KronGradOpFunctor {
void operator()(const DeviceContext& dev_ctx, const framework::Tensor& dout,
const framework::Tensor& x, const framework::Tensor& y,
framework::Tensor* dx, framework::Tensor* dy) {
int ndims = dout.dims().size();
int64_t numel = dout.numel();
int64_t numel_x = x.numel();
int64_t numel_y = y.numel();
const framework::DDim& dim_x = x.dims();
const framework::DDim& dim_y = y.dims();
const framework::DDim& dim_dout = dout.dims();
const framework::DDim stride_x = framework::stride(dim_x);
const framework::DDim stride_y = framework::stride(dim_y);
const framework::DDim stride_dout = framework::stride(dim_dout);
const int64_t* p_stride_x = nullptr;
const int64_t* p_stride_y = nullptr;
const int64_t* p_stride_dout = nullptr;
const int64_t* p_shape_y = nullptr;
#if __NVCC__
thrust::device_vector<int64_t> d_stride_x(ndims);
thrust::device_vector<int64_t> d_stride_y(ndims);
thrust::device_vector<int64_t> d_stride_dout(ndims);
thrust::device_vector<int64_t> d_shape_y(ndims);
thrust::copy(stride_x.Get(), stride_x.Get() + ndims, d_stride_x.begin());
thrust::copy(stride_y.Get(), stride_y.Get() + ndims, d_stride_y.begin());
thrust::copy(stride_dout.Get(), stride_dout.Get() + ndims,
d_stride_dout.begin());
thrust::copy(dim_y.Get(), dim_y.Get() + ndims, d_shape_y.begin());
p_stride_x = thrust::raw_pointer_cast(d_stride_x.data());
p_stride_y = thrust::raw_pointer_cast(d_stride_y.data());
p_stride_dout = thrust::raw_pointer_cast(d_stride_dout.data());
p_shape_y = thrust::raw_pointer_cast(d_shape_y.data());
#else
p_stride_x = stride_x.Get();
p_stride_y = stride_y.Get();
p_stride_dout = stride_dout.Get();
p_shape_y = dim_y.Get();
#endif
// dout_x: dout * kron(ones(X), Y) re-aranged in shape (numel_x, numel_y)
// dout_y: dout * kron(X, ones(Y)) re-aranged in shaoe (numel_y, numel_x)
framework::Tensor dout_x;
T* p_dout_x = nullptr;
if (dx) {
dout_x.mutable_data<T>({numel_x, numel_y}, dev_ctx.GetPlace());
p_dout_x = dout_x.data<T>();
}
framework::Tensor dout_y;
T* p_dout_y = nullptr;
if (dy) {
dout_y.mutable_data<T>({numel_y, numel_x}, dev_ctx.GetPlace());
p_dout_y = dout_y.data<T>();
}
platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
KronGradElemFunctor<T> func(dout.data<T>(), x.data<T>(), y.data<T>(),
p_dout_x, p_dout_y, p_stride_dout, p_stride_x,
p_stride_y, p_shape_y, numel_x, numel_y, ndims);
for_range(func);
// reduce_sum along aixs 1
#if __NVCC__
auto stream = dev_ctx.stream(); // it is a cuda device_context
if (dx) {
TensorReduce<T, T, cub::Sum, IdentityFunctor<T>>(
dout_x, dx, {1}, static_cast<T>(0), cub::Sum(), IdentityFunctor<T>(),
stream);
}
if (dy) {
TensorReduce<T, T, cub::Sum, IdentityFunctor<T>>(
dout_y, dy, {1}, static_cast<T>(0), cub::Sum(), IdentityFunctor<T>(),
stream);
}
#else
auto* place = dev_ctx.eigen_device();
Eigen::array<int, 1> reduce_dim = {1};
if (dx) {
auto eigen_dout_x = framework::EigenMatrix<T>::Reshape(dout_x, 1);
auto eigen_vec_dx = framework::EigenVector<T>::Flatten(*dx);
eigen_vec_dx.device(*place) = eigen_dout_x.sum(reduce_dim);
}
if (dy) {
auto eigen_dout_y = framework::EigenMatrix<T>::Reshape(dout_y, 1);
auto eigen_vec_dy = framework::EigenVector<T>::Flatten(*dy);
eigen_vec_dy.device(*place) = eigen_dout_y.sum(reduce_dim);
}
#endif
}
};
inline framework::Tensor UnsqueezeTo(const framework::Tensor& src, int ndims) {
const framework::DDim& shape = src.dims();
int rank = shape.size();
framework::Tensor res;
res.ShareDataWith(src);
PADDLE_ENFORCE_LE(
rank, ndims,
platform::errors::InvalidArgument(
"The input Tensor's rank should be less than or equal to ndims"
"Received input Tensor's rank = %d, ndims = %d",
rank, ndims));
if (rank < ndims) {
std::vector<int64_t> new_dim(ndims, 1);
for (int i = ndims - rank; i < ndims; i++) {
new_dim[i] = shape[i - ndims + rank];
}
res.Resize(framework::make_ddim(new_dim));
}
return res;
}
template <typename DeviceContext, typename T>
class KronKernel : public framework::OpKernel<T> {
public:
virtual void Compute(const framework::ExecutionContext& ctx) const {
auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto* x = ctx.Input<framework::Tensor>("X");
auto* y = ctx.Input<framework::Tensor>("Y");
auto* out = ctx.Output<framework::Tensor>("Out");
out->mutable_data<T>(ctx.GetPlace());
int ndims = out->dims().size();
framework::Tensor xx = UnsqueezeTo(*x, ndims);
framework::Tensor yy = UnsqueezeTo(*y, ndims);
KronOpFunctor<DeviceContext, T> func;
func(dev_ctx, xx, yy, out);
}
};
template <typename DeviceContext, typename T>
class KronGradKernel : public framework::OpKernel<T> {
public:
virtual void Compute(const framework::ExecutionContext& ctx) const {
auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto* x = ctx.Input<framework::Tensor>("X");
auto* y = ctx.Input<framework::Tensor>("Y");
auto* dout = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<framework::Tensor>(framework::GradVarName("Y"));
if (dx) {
dx->mutable_data<T>(ctx.GetPlace());
}
if (dy) {
dy->mutable_data<T>(ctx.GetPlace());
}
int ndims = dout->dims().size();
framework::Tensor xx = UnsqueezeTo(*x, ndims);
framework::Tensor yy = UnsqueezeTo(*y, ndims);
framework::Tensor* pdxx = nullptr;
framework::Tensor* pdyy = nullptr;
framework::Tensor dxx;
framework::Tensor dyy;
if (dx) {
dxx = UnsqueezeTo(*dx, ndims);
pdxx = &dxx;
}
if (dy) {
dyy = UnsqueezeTo(*dy, ndims);
pdyy = &dyy;
}
KronGradOpFunctor<DeviceContext, T> func;
func(dev_ctx, *dout, xx, yy, pdxx, pdyy);
}
};
} // namespace operators
} // namespace paddle