// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
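
// Forward kernel for the dot product: Out = sum(X * Y) over the last axis.
// For 1-D inputs the result is a scalar; for 2-D inputs of shape [B, N] the
// reduction runs per row, producing one value per batch entry.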
template <typename DeviceContext, typename T>
class DotKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* tensor_x = ctx.Input<Tensor>("X");
    auto* tensor_y = ctx.Input<Tensor>("Y");
    auto* tensor_out = ctx.Output<Tensor>("Out");
    tensor_out->mutable_data<T>(ctx.GetPlace());

#ifdef __NVCC__
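    // GPU build: let Eigen evaluate the elementwise product and reduction
    // on the device.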
    if (1 == tensor_out->dims().size()) {
      auto out = framework::EigenScalar<T>::From(*tensor_out);
      auto x = framework::EigenVector<T>::Flatten(*tensor_x);
      auto y = framework::EigenVector<T>::Flatten(*tensor_y);

      auto& dev = *ctx.template device_context<DeviceContext>().eigen_device();
      out.device(dev) = (x * y).sum();
    } else {
      auto out = EigenMatrix<T>::From(*tensor_out);
      auto x = EigenMatrix<T>::From(*tensor_x);
      auto y = EigenMatrix<T>::From(*tensor_y);

      auto& dev = *ctx.template device_context<DeviceContext>().eigen_device();
      out.device(dev) = (x * y).sum(Eigen::DSizes<int, 1>(1));
    }
#else
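    // CPU build: accumulate the elementwise products by hand. `step` is the
    // size of the last axis, so each run of `step` consecutive elements folds
    // into one output slot (one slot total for 1-D inputs, one per row for
    // 2-D inputs).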
    const auto* data_x = tensor_x->data<T>();
    const auto* data_y = tensor_y->data<T>();
    auto* data_out = tensor_out->data<T>();

    auto x_dims = tensor_x->dims();
    auto step = x_dims[x_dims.size() - 1];
    int size = static_cast<int>(framework::product(x_dims));

    for (int ind = -1, j = 0; j < size; ++j) {
      if (j % step == 0) {
        ++ind;
        data_out[ind] = data_x[j] * data_y[j];
      } else {
        data_out[ind] += data_x[j] * data_y[j];
      }
    }
#endif
  }
};
template <typename DeviceContext, typename T>
class DotGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* tensor_x = ctx.Input<Tensor>("X");
    auto* tensor_y = ctx.Input<Tensor>("Y");
    auto* tensor_dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
    auto* tensor_dx = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto* tensor_dy = ctx.Output<Tensor>(framework::GradVarName("Y"));

    if (tensor_dx) tensor_dx->mutable_data<T>(ctx.GetPlace());
    if (tensor_dy) tensor_dy->mutable_data<T>(ctx.GetPlace());
#ifdef __NVCC__
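    // GPU build: broadcast dOut back to the input shape, then multiply by
    // the other input.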
    if (1 == tensor_dout->dims().size()) {
      auto dout = framework::EigenVector<T>::Flatten(*tensor_dout);

      if (tensor_dx) {
        auto y = framework::EigenVector<T>::Flatten(*tensor_y);
        auto dx = framework::EigenVector<T>::Flatten(*tensor_dx);
        auto& dev =
            *ctx.template device_context<DeviceContext>().eigen_device();
        Eigen::DSizes<int, 1> size(tensor_dx->numel());
        dx.device(dev) = y * dout.broadcast(size);
      }

      if (tensor_dy) {
        auto x = framework::EigenVector<T>::Flatten(*tensor_x);
        auto dy = framework::EigenVector<T>::Flatten(*tensor_dy);
        auto& dev =
            *ctx.template device_context<DeviceContext>().eigen_device();
        Eigen::DSizes<int, 1> size(tensor_dy->numel());
        dy.device(dev) = x * dout.broadcast(size);
      }
    } else {
      auto dout = EigenMatrix<T>::From(*tensor_dout);

      if (tensor_dx) {
        tensor_dx->mutable_data<T>(ctx.GetPlace());
        auto y = EigenMatrix<T>::From(*tensor_y);
        auto dx = EigenMatrix<T>::From(*tensor_dx);
        auto& dev =
            *ctx.template device_context<DeviceContext>().eigen_device();
        Eigen::DSizes<int, 2> size(1, tensor_dx->dims()[1]);
        dx.device(dev) = y * dout.broadcast(size);
      }

      if (tensor_dy) {
        tensor_dy->mutable_data<T>(ctx.GetPlace());
        auto x = EigenMatrix<T>::From(*tensor_x);
        auto dy = EigenMatrix<T>::From(*tensor_dy);
        auto& dev =
            *ctx.template device_context<DeviceContext>().eigen_device();
        Eigen::DSizes<int, 2> size(1, tensor_dy->dims()[1]);
        dy.device(dev) = x * dout.broadcast(size);
      }
    }
#else
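    // CPU build: walk the flattened inputs; `s` indexes the dOut slot for
    // the current row and advances once every `step` elements.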
    const auto* data_dout = tensor_dout->data<T>();

    if (tensor_dx) {
      auto* data_dx = tensor_dx->mutable_data<T>(ctx.GetPlace());
      const auto* data_y = tensor_y->data<T>();
      const framework::DDim& dim = tensor_x->dims();
      size_t N = static_cast<size_t>(framework::product(dim));

      auto step = dim[dim.size() - 1];

      int s = -1;
      for (size_t i = 0; i < N; ++i) {
        if (0 == i % step) ++s;
        data_dx[i] = data_y[i] * data_dout[s];
      }
    }

    if (tensor_dy) {
      auto* data_dy = tensor_dy->mutable_data<T>(ctx.GetPlace());
      const auto* data_x = tensor_x->data<T>();
      const framework::DDim& dim = tensor_y->dims();
      size_t N = static_cast<size_t>(framework::product(dim));

      auto step = dim[dim.size() - 1];

      int s = -1;
      for (size_t i = 0; i < N; ++i) {
        if (0 == i % step) ++s;
        data_dy[i] = data_x[i] * data_dout[s];
      }
    }
#endif
  }
};

}  // namespace operators
}  // namespace paddle
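
// A minimal sketch of how these kernels might be registered in the matching
// dot_op.cc / dot_op.cu (illustrative only; the actual registration and its
// full type list live outside this header):
//
//   namespace ops = paddle::operators;
//   REGISTER_OP_CPU_KERNEL(
//       dot, ops::DotKernel<paddle::platform::CPUDeviceContext, float>,
//       ops::DotKernel<paddle::platform::CPUDeviceContext, double>);
//   REGISTER_OP_CPU_KERNEL(
//       dot_grad,
//       ops::DotGradKernel<paddle::platform::CPUDeviceContext, float>,
//       ops::DotGradKernel<paddle::platform::CPUDeviceContext, double>);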