@@ -10,6 +10,7 @@
 limitations under the License. */

 #pragma once
+#include <vector>
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/math/blas.h"
@@ -27,17 +28,33 @@ using Array1 = Eigen::DSizes<int64_t, 1>;
 using Array2 = Eigen::DSizes<int64_t, 2>;
 using IndexPair = Eigen::IndexPair<int>;

-static inline void CalcMatrixShape(const Tensor& weight, const int dim, int* h,
-                                   int* w) {
-  auto weight_dims = weight.dims();
-  *h = 1;
-  *w = 1;
-  for (int i = 0; i < weight_dims.size(); i++) {
-    if (i <= dim) {
-      *h *= weight_dims[i];
-    } else {
-      *w *= weight_dims[i];
-    }
-  }
-}
+template <typename DeviceContext, typename T>
+static inline void TransCompute(const int rank, const Tensor& in, Tensor* out,
+                                const std::vector<int>& perm,
+                                const DeviceContext& dev_ctx) {
+  if (rank <= 1 || rank > 5) {
+    PADDLE_THROW("Invalid weight rank.");
+  }
+
+  switch (rank) {
+    case 2:
+      math::Transpose<DeviceContext, T, 2> trans2;
+      trans2(dev_ctx, in, out, perm);
+      break;
+    case 3:
+      math::Transpose<DeviceContext, T, 3> trans3;
+      trans3(dev_ctx, in, out, perm);
+      break;
+    case 4:
+      math::Transpose<DeviceContext, T, 4> trans4;
+      trans4(dev_ctx, in, out, perm);
+      break;
+    case 5:
+      math::Transpose<DeviceContext, T, 5> trans5;
+      trans5(dev_ctx, in, out, perm);
+      break;
+    default:
+      break;
+  }
+}

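Note on TransCompute above: it only dispatches to the rank-templated math::Transpose functor, with the convention that output axis k takes input axis perm[k]. A minimal standalone sketch of that permutation semantics (plain C++; Permute is a hypothetical stand-in for the functor, not Paddle code):

```cpp
#include <cstddef>
#include <iostream>
#include <numeric>
#include <vector>

// Permute the axes of a row-major tensor `in` with shape `dims` so that
// output axis k is input axis perm[k] (same convention as TransCompute).
std::vector<float> Permute(const std::vector<float>& in,
                           const std::vector<int>& dims,
                           const std::vector<int>& perm) {
  const int rank = static_cast<int>(dims.size());
  std::vector<int> out_dims(rank);
  for (int k = 0; k < rank; ++k) out_dims[k] = dims[perm[k]];

  // Row-major strides of the input.
  std::vector<size_t> stride(rank, 1);
  for (int k = rank - 2; k >= 0; --k) stride[k] = stride[k + 1] * dims[k + 1];

  std::vector<float> out(in.size());
  std::vector<int> idx(rank, 0);  // multi-index over the output
  for (size_t o = 0; o < out.size(); ++o) {
    size_t src = 0;
    for (int k = 0; k < rank; ++k) src += idx[k] * stride[perm[k]];
    out[o] = in[src];
    for (int k = rank - 1; k >= 0; --k) {  // odometer-style increment
      if (++idx[k] < out_dims[k]) break;
      idx[k] = 0;
    }
  }
  return out;
}

int main() {
  // A 2x3 matrix; perm {1, 0} is a plain transpose.
  std::vector<float> m(6);
  std::iota(m.begin(), m.end(), 0.f);  // 0 1 2 / 3 4 5
  for (float x : Permute(m, {2, 3}, {1, 0})) std::cout << x << ' ';
  std::cout << '\n';  // prints: 0 3 1 4 2 5
}
```

With perm = {1, 0} the rank-2 case is an ordinary matrix transpose; the kernels below use the same mechanism to move the normalized axis to the front.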
@@ -83,6 +100,7 @@ template <typename DeviceContext, typename T>
 class SpectralNormKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
+    auto& dev_ctx = ctx.template device_context<DeviceContext>();
     auto weight = ctx.Input<Tensor>("Weight");
     auto u = ctx.Input<Tensor>("U");
     auto v = ctx.Input<Tensor>("V");
@@ -92,10 +110,32 @@ class SpectralNormKernel : public framework::OpKernel<T> {
     int power_iters = ctx.Attr<int>("power_iters");
     float eps = ctx.Attr<float>("eps");

+    const int h = u->dims()[0];
+    const int w = v->dims()[0];
+
     Tensor weight_mat;
-    int h, w;
-    CalcMatrixShape(*weight, dim, &h, &w);
-    TensorCopySync(*weight, ctx.GetPlace(), &weight_mat);
+    auto dims = weight->dims();
+    const int rank = dims.size();
+    std::vector<int> real_dims;
+    if (dim != 0) {
+      std::vector<int> perm;
+      perm.push_back(dim);
+      real_dims.push_back(dims[dim]);
+      for (int i = 0; i < rank; i++) {
+        if (i != dim) {
+          perm.push_back(i);
+          real_dims.push_back(dims[i]);
+        }
+      }
+      weight_mat.mutable_data<T>(framework::make_ddim(real_dims),
+                                 ctx.GetPlace());
+      TransCompute<DeviceContext, T>(rank, *weight, &weight_mat, perm, dev_ctx);
+    } else {
+      for (int i = 0; i < rank; i++) {
+        real_dims.push_back(dims[i]);
+      }
+      TensorCopySync(*weight, ctx.GetPlace(), &weight_mat);
+    }
     weight_mat = weight_mat.Resize({h, w});

     Tensor sigma;
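The bookkeeping above reflects how spectral normalization views the weight: a 2-D matrix whose row count h = u->dims()[0] must equal dims[dim] and whose column count w is the product of the remaining dimensions. Since Resize({h, w}) is a plain row-major reshape, the dim axis must first be transposed to the front. A standalone re-run of the same loop for a hypothetical conv weight of shape {3, 8, 5, 5} with dim = 1:

```cpp
#include <iostream>
#include <vector>

int main() {
  // Hypothetical conv weight shape {3, 8, 5, 5}, normalized along dim = 1.
  std::vector<int> dims = {3, 8, 5, 5};
  const int dim = 1;
  // Same bookkeeping as the kernel: move `dim` to the front.
  std::vector<int> perm = {dim};
  std::vector<int> real_dims = {dims[dim]};
  for (int i = 0; i < static_cast<int>(dims.size()); i++) {
    if (i != dim) {
      perm.push_back(i);
      real_dims.push_back(dims[i]);
    }
  }
  // perm      = {1, 0, 2, 3}
  // real_dims = {8, 3, 5, 5}  ->  matrix view: h = 8, w = 3 * 5 * 5 = 75
  int h = real_dims[0], w = 1;
  for (size_t i = 1; i < real_dims.size(); i++) w *= real_dims[i];
  std::cout << "h = " << h << ", w = " << w << "\n";  // h = 8, w = 75
}
```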
@@ -106,7 +146,25 @@ class SpectralNormKernel : public framework::OpKernel<T> {
     CalcMatrixSigmaAndNormWeight<DeviceContext, T>(
         &sigma, &(uu.Resize({h, 1})), &(vv.Resize({w, 1})), &weight_mat,
         power_iters, eps, ctx);
-    TensorCopySync(weight_mat.Resize(out->dims()), ctx.GetPlace(), out);
+
+    if (dim != 0) {
+      std::vector<int> perm;
+      for (int i = 0; i < rank; i++) {
+        if (i < dim) {
+          perm.push_back(i + 1);
+        } else if (i == dim) {
+          perm.push_back(0);
+        } else {
+          perm.push_back(i);
+        }
+      }
+      out->mutable_data<T>(dims, ctx.GetPlace());
+      TransCompute<DeviceContext, T>(
+          rank, weight_mat.Resize(framework::make_ddim(real_dims)), out, perm,
+          dev_ctx);
+    } else {
+      TensorCopySync(weight_mat.Resize(dims), ctx.GetPlace(), out);
+    }
   }
 };
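CalcMatrixSigmaAndNormWeight itself lies outside this diff. Assuming it follows the usual spectral-norm recipe (Miyato et al., 2018), it runs power_iters rounds of power iteration on the h x w matrix view and then divides the matrix by the estimated largest singular value sigma. A hedged sketch of that estimate in plain C++ (EstimateSigma is hypothetical, not the Paddle implementation):

```cpp
#include <cmath>
#include <iostream>
#include <vector>

// Power iteration on a dense row-major h x w matrix W:
// repeat  v <- normalize(W^T u), u <- normalize(W v),
// then sigma = u^T W v approximates the largest singular value of W.
float EstimateSigma(const std::vector<float>& W, int h, int w,
                    std::vector<float>* u, std::vector<float>* v,
                    int power_iters, float eps) {
  auto normalize = [eps](std::vector<float>* x) {
    float n = 0.f;
    for (float e : *x) n += e * e;
    n = std::sqrt(n) + eps;  // eps guards against division by zero
    for (float& e : *x) e /= n;
  };
  for (int it = 0; it < power_iters; ++it) {
    for (int j = 0; j < w; ++j) {  // v = W^T u
      (*v)[j] = 0.f;
      for (int i = 0; i < h; ++i) (*v)[j] += W[i * w + j] * (*u)[i];
    }
    normalize(v);
    for (int i = 0; i < h; ++i) {  // u = W v
      (*u)[i] = 0.f;
      for (int j = 0; j < w; ++j) (*u)[i] += W[i * w + j] * (*v)[j];
    }
    normalize(u);
  }
  float sigma = 0.f;  // sigma = u^T W v
  for (int i = 0; i < h; ++i)
    for (int j = 0; j < w; ++j) sigma += (*u)[i] * W[i * w + j] * (*v)[j];
  return sigma;
}

int main() {
  // 2x2 example: W = diag(3, 1), so the largest singular value is 3.
  std::vector<float> W = {3.f, 0.f, 0.f, 1.f};
  std::vector<float> u = {1.f, 1.f}, v = {1.f, 1.f};
  std::cout << EstimateSigma(W, 2, 2, &u, &v, 10, 1e-12f) << "\n";  // ~3
}
```

The eps term plays the same role as the kernel's eps attribute: it keeps the normalization finite when Wv or W^T u is close to zero.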
@@ -115,6 +173,7 @@ class SpectralNormGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
+    auto& dev_ctx = ctx.template device_context<DeviceContext>();
     auto blas = math::GetBlas<DeviceContext, T>(ctx);
     auto weight = ctx.Input<Tensor>("Weight");
     auto u = ctx.Input<Tensor>("U");
@@ -126,11 +185,37 @@ class SpectralNormGradKernel : public framework::OpKernel<T> {
     int power_iters = ctx.Attr<int>("power_iters");
     float eps = ctx.Attr<float>("eps");

+    const int h = u->dims()[0];
+    const int w = v->dims()[0];
+
     Tensor weight_mat, out_grad_mat;
-    int h, w;
-    CalcMatrixShape(*weight, dim, &h, &w);
-    TensorCopySync(*weight, ctx.GetPlace(), &weight_mat);
-    TensorCopySync(*out_grad, ctx.GetPlace(), &out_grad_mat);
+    auto dims = weight->dims();
+    const int rank = dims.size();
+    std::vector<int> real_dims;
+    if (dim != 0) {
+      std::vector<int> perm;
+      perm.push_back(dim);
+      real_dims.push_back(dims[dim]);
+      for (int i = 0; i < rank; i++) {
+        if (i != dim) {
+          perm.push_back(i);
+          real_dims.push_back(dims[i]);
+        }
+      }
+      weight_mat.mutable_data<T>(framework::make_ddim(real_dims),
+                                 ctx.GetPlace());
+      out_grad_mat.mutable_data<T>(framework::make_ddim(real_dims),
+                                   ctx.GetPlace());
+      TransCompute<DeviceContext, T>(rank, *weight, &weight_mat, perm, dev_ctx);
+      TransCompute<DeviceContext, T>(rank, *out_grad, &out_grad_mat, perm,
+                                     dev_ctx);
+    } else {
+      for (int i = 0; i < rank; i++) {
+        real_dims.push_back(dims[i]);
+      }
+      TensorCopySync(*weight, ctx.GetPlace(), &weight_mat);
+      TensorCopySync(*out_grad, ctx.GetPlace(), &out_grad_mat);
+    }
     weight_mat = weight_mat.Resize({h, w});
     out_grad_mat = out_grad_mat.Resize({h, w});

@@ -148,21 +233,37 @@ class SpectralNormGradKernel : public framework::OpKernel<T> {
     blas.MatMul(uu.Resize({h, 1}), false, vv.Resize({w, 1}), false, T(1), &uv,
                 T(0));

-    Tensor weight_grad_mat, ones;
+    Tensor weight_grad_mat;
     weight_grad_mat.mutable_data<T>({h, w}, ctx.GetPlace());
-    ones.mutable_data<T>({h, w}, ctx.GetPlace());
     auto weight_grad_mat_t = EigenTensor<T, 2>::From(weight_grad_mat);
     auto weight_mat_t = EigenTensor<T, 2>::From(weight_mat);
     auto out_grad_mat_t = EigenTensor<T, 2>::From(out_grad_mat);
     auto sigma_t = EigenTensor<T, 2>::From(sigma);
     auto uv_t = EigenTensor<T, 2>::From(uv);
-    auto ones_t = EigenTensor<T, 2>::From(ones).setConstant((T)1);
     weight_mat_t.device(place) =
         weight_mat_t.sum().eval().reshape(Array2(1, 1)).broadcast(Array2(h, w));
     weight_grad_mat_t.device(place) =
-        out_grad_mat_t * (ones_t - uv_t * weight_mat_t) / sigma_t;
-    TensorCopySync(weight_grad_mat.Resize(weight_grad->dims()), ctx.GetPlace(),
-                   weight_grad);
+        out_grad_mat_t * (out_grad_mat_t.constant(1.0) - uv_t * weight_mat_t) /
+        sigma_t;
+
+    if (dim != 0) {
+      std::vector<int> perm;
+      for (int i = 0; i < rank; i++) {
+        if (i < dim) {
+          perm.push_back(i + 1);
+        } else if (i == dim) {
+          perm.push_back(0);
+        } else {
+          perm.push_back(i);
+        }
+      }
+      weight_grad->mutable_data<T>(dims, ctx.GetPlace());
+      TransCompute<DeviceContext, T>(
+          rank, weight_grad_mat.Resize(framework::make_ddim(real_dims)),
+          weight_grad, perm, dev_ctx);
+    } else {
+      TensorCopySync(weight_grad_mat.Resize(dims), ctx.GetPlace(), weight_grad);
+    }

   }
 };
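For comparison with the broadcast Eigen expression above, assuming the standard derivation: with $G = \partial L / \partial \hat{W}$ the incoming gradient, $\hat{W} = W / \sigma$, and $\partial \sigma / \partial W = u v^{\top}$ for the dominant singular pair $(u, v)$, the reference gradient is

$$
\frac{\partial L}{\partial W}
  = \frac{1}{\sigma}\left(G - \langle G, \hat{W} \rangle\, u v^{\top}\right),
\qquad
\langle G, \hat{W} \rangle = \sum_{i,j} G_{ij} \hat{W}_{ij},
$$

a rank-one $u v^{\top}$ correction scaled by the inner product of the incoming gradient with the normalized weight. The trailing transpose block then maps this $h \times w$ gradient back to the original axis order, mirroring the inverse permutation used in the forward kernel.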