You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
Paddle/paddle/fluid/operators/bmm_op.h

142 lines
4.8 KiB

/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#ifndef PADDLE_FLUID_OPERATORS_BMM_OP_H_
#define PADDLE_FLUID_OPERATORS_BMM_OP_H_
#include <algorithm>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
// Resizes `x` to the 3-D shape {batch_size_, rows, cols} taken from
// `descriptor`.
//
// rows/cols are descriptor.height_/descriptor.width_, exchanged when
// descriptor.trans_ is set. Only Tensor::Resize is called here.
static void ReshapeTensorIntoMatrixSequence(
    framework::Tensor *x, const math::MatDescriptor &descriptor) {
  const bool transposed = descriptor.trans_;
  // For a transposed descriptor the tensor keeps the swapped (width, height)
  // extents, so pick the two trailing dimensions accordingly.
  const int64_t rows = transposed ? descriptor.width_ : descriptor.height_;
  const int64_t cols = transposed ? descriptor.height_ : descriptor.width_;
  x->Resize({descriptor.batch_size_, rows, cols});
}
// Resizes `x`, `y` and `out` in place so each is a batch of matrices.
//
// Parameters:
//   x, y    - operands; their current dims() are used to build the matrix
//             descriptors, after which each is resized to
//             {batch_size, rows, cols} via ReshapeTensorIntoMatrixSequence.
//   out     - resized to {max(batch_x, batch_y), height_x, width_y}.
//   trans_x - forwarded to the descriptor built for `x`.
//   trans_y - forwarded to the descriptor built for `y`.
//
// The transpose flags used to be accepted but ignored (both descriptors were
// always built with `false`); they are now honoured. Callers that pass
// `false, false` observe exactly the previous behaviour.
// NOTE(review): this assumes CreateMatrixDescriptor reports the logical
// (post-transpose) height_/width_ when the flag is set - confirm in
// math/blas.h.
static void ReshapeXYOutIntoMatrixSequence(framework::Tensor *x,
                                           framework::Tensor *y,
                                           framework::Tensor *out, bool trans_x,
                                           bool trans_y) {
  auto x_dim = x->dims();
  auto y_dim = y->dims();
  auto mat_dim_x = math::CreateMatrixDescriptor(x_dim, 0, trans_x);
  auto mat_dim_y = math::CreateMatrixDescriptor(y_dim, 0, trans_y);
  // Size the output before touching x/y; the descriptors above already hold
  // everything needed.
  out->Resize({std::max(mat_dim_x.batch_size_, mat_dim_y.batch_size_),
               mat_dim_x.height_, mat_dim_y.width_});
  ReshapeTensorIntoMatrixSequence(x, mat_dim_x);
  ReshapeTensorIntoMatrixSequence(y, mat_dim_y);
}
// Forward kernel of the bmm operator.
//
// Reads inputs "X" and "Y", allocates output "Out" on the execution place and
// fills it with the result of blas.MatMul over the two inputs.
template <typename DeviceContext, typename T>
class BmmKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &context) const override {
    const Tensor &lhs = *context.Input<Tensor>("X");
    const Tensor &rhs = *context.Input<Tensor>("Y");
    Tensor *result = context.Output<Tensor>("Out");
    // Allocate the output buffer first; the dims of "Out" are not set here
    // (presumably done during shape inference - confirm in the op definition).
    result->mutable_data<T>(context.GetPlace());

    auto blas = math::GetBlas<DeviceContext, T>(context);
    // Neither operand is transposed in the forward pass.
    const auto lhs_desc = math::CreateMatrixDescriptor(lhs.dims(), 0, false);
    const auto rhs_desc = math::CreateMatrixDescriptor(rhs.dims(), 0, false);

    // No "alpha" attribute is read: the scalar arguments are fixed at T(1)
    // and T(0).
    blas.MatMul(lhs, lhs_desc, rhs, rhs_desc, T(1), result, T(0));
  }
};
// Backward kernel of the bmm operator.
//
// Given "X", "Y" and the gradient of "Out", fills the gradients of "X" and
// "Y" (each only when the corresponding output is present).
template <typename DeviceContext, typename T>
class BmmGradKernel : public framework::OpKernel<T> {
 public:
  // Allocates `out` on the execution place and writes into it the result of
  // blas.MatMul over `a` and `b`, whose descriptors are built with the given
  // transpose flags.
  void MatMul(const framework::ExecutionContext &context,
              const framework::Tensor &a, bool trans_a,
              const framework::Tensor &b, bool trans_b,
              framework::Tensor *out) const {
    out->mutable_data<T>(context.GetPlace());
    auto blas = math::GetBlas<DeviceContext, T>(context);
    const auto desc_a = math::CreateMatrixDescriptor(a.dims(), 0, trans_a);
    const auto desc_b = math::CreateMatrixDescriptor(b.dims(), 0, trans_b);
    blas.MatMul(a, desc_a, b, desc_b, T(1), out, T(0));
  }

  // Same as MatMul, but does nothing when `out` is null (that gradient was
  // not requested).
  void CalcInputGrad(const framework::ExecutionContext &context,
                     const framework::Tensor &a, bool trans_a,
                     const framework::Tensor &b, bool trans_b,
                     framework::Tensor *out) const {
    if (out != nullptr) {
      MatMul(context, a, trans_a, b, trans_b, out);
    }
  }

  void Compute(const framework::ExecutionContext &context) const override {
    // Work on local Tensor copies: they are resized below, and the Tensor
    // objects handed out by the context are const.
    framework::Tensor x = *context.Input<framework::Tensor>("X");
    framework::Tensor y = *context.Input<framework::Tensor>("Y");
    framework::Tensor dout =
        *context.Input<framework::Tensor>(framework::GradVarName("Out"));
    framework::Tensor *dx =
        context.Output<framework::Tensor>(framework::GradVarName("X"));
    framework::Tensor *dy =
        context.Output<framework::Tensor>(framework::GradVarName("Y"));

    ReshapeXYOutIntoMatrixSequence(&x, &y, &dout, false, false);

    // Temporarily give each requested gradient the dims of its (reshaped)
    // input, remembering the original dims so they can be restored.
    framework::DDim dx_saved_dims;
    bool dx_resized = false;
    if (dx != nullptr) {
      dx_saved_dims = dx->dims();
      dx_resized = (dx_saved_dims != x.dims());
      if (dx_resized) {
        dx->Resize(x.dims());
      }
    }

    framework::DDim dy_saved_dims;
    bool dy_resized = false;
    if (dy != nullptr) {
      dy_saved_dims = dy->dims();
      dy_resized = (dy_saved_dims != y.dims());
      if (dy_resized) {
        dy->Resize(y.dims());
      }
    }

    // dx is computed from (dout, y with the transpose flag set);
    // dy is computed from (x with the transpose flag set, dout).
    CalcInputGrad(context, dout, false, y, true, dx);
    CalcInputGrad(context, x, true, dout, false, dy);

    // Put back the dims the gradient tensors had on entry.
    if (dx_resized) {
      dx->Resize(dx_saved_dims);
    }
    if (dy_resized) {
      dy->Resize(dy_saved_dims);
    }
  }
};
} // namespace operators
} // namespace paddle
#endif // PADDLE_FLUID_OPERATORS_BMM_OP_H_