parent 7b5e23c034
commit 1c08a2136e

@@ -0,0 +1,236 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/addmm_op.h"
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif

namespace paddle {
namespace operators {

using framework::OpKernelType;
using framework::Tensor;

class AddMMOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE_EQ(ctx->HasInput("Input"), true,
                      platform::errors::NotFound(
                          "Input(Input) of AddMMOp should not be null."));
    PADDLE_ENFORCE_EQ(
        ctx->HasInput("X"), true,
        platform::errors::NotFound("Input(X) of AddMMOp should not be null."));
    PADDLE_ENFORCE_EQ(
        ctx->HasInput("Y"), true,
        platform::errors::NotFound("Input(Y) of AddMMOp should not be null."));
    PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
                      platform::errors::NotFound(
                          "Output(Out) of AddMMOp should not be null."));

    auto input_dims = ctx->GetInputDim("Input");
    auto x_dims = ctx->GetInputDim("X");
    auto y_dims = ctx->GetInputDim("Y");

    auto ndim_input = input_dims.size();
    auto ndim_x = x_dims.size();
    auto ndim_y = y_dims.size();

    float alpha = ctx->Attrs().Get<float>("Alpha");
    float beta = ctx->Attrs().Get<float>("Beta");

    VLOG(3) << "addmm operator input.shape=" << input_dims
            << " x.shape=" << x_dims << " y.shape=" << y_dims
            << " beta=" << beta << " alpha=" << alpha
            << " ndim_input=" << ndim_input << " ndim_x=" << ndim_x
            << " ndim_y=" << ndim_y;

    PADDLE_ENFORCE_NE(framework::product(input_dims), 0,
                      platform::errors::PreconditionNotMet(
                          "The Input variable Input(%s) has not "
                          "been initialized. You may need to confirm "
                          "if you put exe.run(startup_program) "
                          "after optimizer.minimize function.",
                          ctx->Inputs("Input").front()));

    PADDLE_ENFORCE_NE(framework::product(x_dims), 0,
                      platform::errors::PreconditionNotMet(
                          "The Input variable X(%s) has not "
                          "been initialized. You may need to confirm "
                          "if you put exe.run(startup_program) "
                          "after optimizer.minimize function.",
                          ctx->Inputs("X").front()));

    PADDLE_ENFORCE_NE(framework::product(y_dims), 0,
                      platform::errors::PreconditionNotMet(
                          "The Input variable Y(%s) has not "
                          "been initialized. You may need to confirm "
                          "if you put exe.run(startup_program) "
                          "after optimizer.minimize function.",
                          ctx->Inputs("Y").front()));
    // dim check
    PADDLE_ENFORCE_EQ(ndim_input, 2,
                      platform::errors::InvalidArgument(
                          "The input tensor input's dimension must be 2. "
                          "But received input's dimension = [%s].",
                          ndim_input));
    PADDLE_ENFORCE_EQ(ndim_x, 2,
                      platform::errors::InvalidArgument(
                          "The input tensor x's dimension must be 2. "
                          "But received x's dimension = [%s].",
                          ndim_x));
    PADDLE_ENFORCE_EQ(ndim_y, 2,
                      platform::errors::InvalidArgument(
                          "The input tensor y's dimension must be 2. "
                          "But received y's dimension = [%s].",
                          ndim_y));

    std::vector<int64_t> output_dims;
    output_dims.push_back(x_dims[0]);
    output_dims.push_back(y_dims[1]);

    ctx->SetOutputDim("Out", framework::make_ddim(output_dims));
    ctx->ShareLoD("Input", /*->*/ "Out");
  }

  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    framework::LibraryType library = framework::LibraryType::kPlain;
    framework::DataLayout layout = framework::DataLayout::kAnyLayout;
    int customized_type_value =
        framework::OpKernelType::kDefaultCustomizedTypeValue;
    auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
    if (library == framework::LibraryType::kPlain &&
        platform::CanMKLDNNBeUsed(ctx)) {
      library = framework::LibraryType::kMKLDNN;
      layout = framework::DataLayout::kMKLDNN;

      if (input_data_type == framework::DataTypeTrait<int8_t>::DataType() ||
          input_data_type == framework::DataTypeTrait<uint8_t>::DataType()) {
        customized_type_value = kMULMKLDNNINT8;
      }
    }
#endif

    return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout,
                                   library, customized_type_value);
  }
};

class AddMMOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("Input", "(Tensor), tensor to be added to the final result.");
    AddInput("X", "(Tensor), The first input tensor for mul.");
    AddInput("Y", "(Tensor), The second input tensor for mul.");
    AddOutput("Out", "(Tensor), The output tensor of addmm op.");
    AddAttr<bool>("use_mkldnn",
                  "(bool, default false) Only used in mkldnn kernel")
        .SetDefault(false);
    AddAttr<float>("Alpha", "coefficient of x*y.").SetDefault(1.0f);
    AddAttr<float>("Beta", "coefficient of input.").SetDefault(1.0f);
    AddComment(R"DOC(
AddMM Operator.
This operator computes the matrix product of $x$ and $y$ scaled by
coefficient $alpha$, then adds $input$ scaled by coefficient $beta$ to the
result. The equation is:

$$Out = alpha * x * y + beta * input$$

$x$ and $y$ must be two-dimensional, and $input$ must be broadcastable to the
shape of $x * y$.
)DOC");
  }
};
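
As a reference for the DOC equation above, here is a minimal NumPy sketch of what addmm computes, including the broadcast of Input; shapes and coefficients here are illustrative, not taken from the operator's tests:

import numpy as np

M, K, N = 4, 3, 5
x = np.random.rand(M, K)    # (M, K)
y = np.random.rand(K, N)    # (K, N)
inp = np.random.rand(1, N)  # row vector, broadcast along dim 0 to (M, N)
alpha, beta = 2.0, 0.5

# Out = alpha * x * y + beta * input, with input broadcast to (M, N)
out = alpha * np.dot(x, y) + beta * np.broadcast_to(inp, (M, N))
assert out.shape == (M, N)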

class AddMMGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE_EQ(
        ctx->HasInput("Input"), true,
        platform::errors::NotFound("Input(Input) should not be null"));
    PADDLE_ENFORCE_EQ(
        ctx->HasInput("X"), true,
        platform::errors::NotFound("Input(X) should not be null"));
    PADDLE_ENFORCE_EQ(
        ctx->HasInput("Y"), true,
        platform::errors::NotFound("Input(Y) should not be null"));
    PADDLE_ENFORCE_EQ(
        ctx->HasInput(framework::GradVarName("Out")), true,
        platform::errors::NotFound("Input(Out@GRAD) should not be null"));
    const auto& input_dims = ctx->GetInputDim("Input");
    const auto& x_dims = ctx->GetInputDim("X");
    const auto& y_dims = ctx->GetInputDim("Y");

    auto input_grad_name = framework::GradVarName("Input");
    auto x_grad_name = framework::GradVarName("X");
    auto y_grad_name = framework::GradVarName("Y");

    if (ctx->HasOutput(input_grad_name)) {
      ctx->SetOutputDim(input_grad_name, input_dims);
    }
    if (ctx->HasOutput(x_grad_name)) {
      ctx->SetOutputDim(x_grad_name, x_dims);
    }
    if (ctx->HasOutput(y_grad_name)) {
      ctx->SetOutputDim(y_grad_name, y_dims);
    }
  }
};

template <typename T>
class AddMMOpGradMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> retv) const override {
    retv->SetType("addmm_grad");
    retv->SetInput("Input", this->Input("Input"));
    retv->SetInput("X", this->Input("X"));
    retv->SetInput("Y", this->Input("Y"));
    retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    retv->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input"));
    retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
    retv->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
    retv->SetAttrMap(this->Attrs());
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

REGISTER_OPERATOR(addmm, ops::AddMMOp, ops::AddMMOpMaker,
                  ops::AddMMOpGradMaker<paddle::framework::OpDesc>,
                  ops::AddMMOpGradMaker<paddle::imperative::OpBase>);

REGISTER_OPERATOR(addmm_grad, ops::AddMMGradOp);

REGISTER_OP_CPU_KERNEL(
    addmm, ops::AddMMKernel<paddle::platform::CPUDeviceContext, float>,
    ops::AddMMKernel<paddle::platform::CPUDeviceContext, double>);

REGISTER_OP_CPU_KERNEL(
    addmm_grad, ops::AddMMGradKernel<paddle::platform::CPUDeviceContext, float>,
    ops::AddMMGradKernel<paddle::platform::CPUDeviceContext, double>);

@@ -0,0 +1,24 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/addmm_op.h"

namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_CUDA_KERNEL(addmm, ops::AddMMKernel<plat::CUDADeviceContext, float>,
                        ops::AddMMKernel<plat::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(addmm_grad,
                        ops::AddMMGradKernel<plat::CUDADeviceContext, float>,
                        ops::AddMMGradKernel<plat::CUDADeviceContext, double>);

@@ -0,0 +1,193 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <boost/preprocessor/repetition/repeat.hpp>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
template <typename T, size_t D, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;

using Array1 = Eigen::DSizes<int64_t, 1>;
using Array2 = Eigen::DSizes<int64_t, 2>;

constexpr int kMULMKLDNNINT8 = 1;

template <typename DeviceContext, typename T>
class AddMMKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    const Tensor* input = context.Input<Tensor>("Input");
    const Tensor* x = context.Input<Tensor>("X");
    const Tensor* y = context.Input<Tensor>("Y");

    auto input_dims = input->dims();
    auto x_dims = x->dims();
    auto y_dims = y->dims();

    // broadcast mode check
    if (x_dims[0] != input_dims[0]) {
      PADDLE_ENFORCE_EQ(input_dims[0], 1,
                        platform::errors::InvalidArgument(
                            "When x_dims[0] is not equal to input_dims[0], "
                            "input_dims[0] must be 1 but got %s",
                            input_dims[0]));
      PADDLE_ENFORCE_EQ(
          y_dims[1] == input_dims[1] || input_dims[1] == 1, true,
          platform::errors::InvalidArgument(
              "The input tensor shape mismatch, input shape=[%s], "
              "x shape=[%s], y shape=[%s]",
              input_dims, x_dims, y_dims));
    }
    // broadcast mode check
    if (y_dims[1] != input_dims[1]) {
      PADDLE_ENFORCE_EQ(input_dims[1], 1,
                        platform::errors::InvalidArgument(
                            "When y_dims[1] is not equal to input_dims[1], "
                            "input_dims[1] must be 1 but got %s",
                            input_dims[1]));
      PADDLE_ENFORCE_EQ(
          x_dims[0] == input_dims[0] || input_dims[0] == 1, true,
          platform::errors::InvalidArgument(
              "The input tensor shape mismatch, input shape=[%s], "
              "x shape=[%s], y shape=[%s]",
              input_dims, x_dims, y_dims));
    }
    // matmul dim check
    PADDLE_ENFORCE_EQ(
        x_dims[1], y_dims[0],
        platform::errors::InvalidArgument(
            "The input tensor X's width must be equal to matrix Y's height. "
            "But received X's shape = [%s], Y's shape = [%s].",
            x_dims[1], y_dims[0]));

    auto* out = context.Output<Tensor>("Out");
    out->mutable_data<T>({x_dims[0], y_dims[1]}, context.GetPlace());

    float alpha = context.template Attr<float>("Alpha");
    float beta = context.template Attr<float>("Beta");

    auto blas = math::GetBlas<DeviceContext, T>(context);

    // calc broadcast dim
    Array2 bcast_dims;
    bcast_dims[0] = x_dims[0] / input_dims[0];
    bcast_dims[1] = y_dims[1] / input_dims[1];
    VLOG(3) << "bcast_dims=[" << bcast_dims[0] << "," << bcast_dims[1] << "]";
    // broadcast using eigen
    auto eigen_input = EigenTensor<T, 2>::From(*input);
    auto eigen_out = EigenTensor<T, 2>::From(*out);
    auto& place =
        *context.template device_context<DeviceContext>().eigen_device();
    eigen_out.device(place) = eigen_input.broadcast(bcast_dims);

    blas.GEMM(false, false, x_dims[0], y_dims[1], x_dims[1], alpha,
              x->data<T>(), x_dims[1], y->data<T>(), y_dims[1], beta,
              out->data<T>(), y_dims[1]);
  }
};
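
The kernel above materializes the broadcast of Input into Out with Eigen, then lets the GEMM's beta coefficient accumulate onto it. A rough NumPy analogue of those two steps, under the same 2-D shape assumptions (the function name is illustrative, not part of the operator):

import numpy as np

def addmm_forward_reference(inp, x, y, alpha=1.0, beta=1.0):
    m, n = x.shape[0], y.shape[1]
    # Mirrors bcast_dims[0] = x_dims[0] / input_dims[0] and
    # bcast_dims[1] = y_dims[1] / input_dims[1].
    bcast = (m // inp.shape[0], n // inp.shape[1])
    out = np.tile(inp, bcast)            # step 1: broadcast input into out
    return alpha * (x @ y) + beta * out  # step 2: GEMM accumulates onto out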

template <typename DeviceContext, typename T>
class AddMMGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* x = ctx.Input<framework::LoDTensor>("X");
    auto* y = ctx.Input<framework::LoDTensor>("Y");
    auto* dout = ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"));
    auto in_dims = ctx.Input<framework::LoDTensor>("Input")->dims();
    auto* dinput =
        ctx.Output<framework::LoDTensor>(framework::GradVarName("Input"));
    auto* dx = ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
    auto* dy = ctx.Output<framework::LoDTensor>(framework::GradVarName("Y"));

    float alpha = ctx.Attr<float>("Alpha");
    float beta = ctx.Attr<float>("Beta");

    int total_elems = 0;

    VLOG(3) << "alpha: " << alpha << " beta: " << beta;

    if (dinput != nullptr) {
      dinput->set_lod(dout->lod());
    }
    if (dx != nullptr) {
      dx->set_lod(x->lod());
    }
    if (dy != nullptr) {
      dy->set_lod(y->lod());
    }

    auto& dev_ctx = ctx.template device_context<DeviceContext>();
    auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
    if (dinput) {
      dinput->mutable_data<T>(ctx.GetPlace());
      total_elems = in_dims[0] * in_dims[1];
      auto& place =
          *ctx.template device_context<DeviceContext>().eigen_device();
      auto eigen_dout = EigenTensor<T, 2>::From(*dout);
      auto eigen_dinput = EigenTensor<T, 2>::From(*dinput);

      bool row_compress = in_dims[0] != dout->dims()[0];
      bool col_compress = in_dims[1] != dout->dims()[1];
      auto eigen_dinput_shape = Array2(dinput->dims()[0], dinput->dims()[1]);

      if (row_compress && col_compress) {
        eigen_dinput.device(place) =
            eigen_dout.sum().eval().reshape(eigen_dinput_shape);
      } else if (row_compress) {
        eigen_dinput.device(place) =
            eigen_dout.sum(Array1(0)).eval().reshape(eigen_dinput_shape);
      } else if (col_compress) {
        eigen_dinput.device(place) =
            eigen_dout.sum(Array1(1)).eval().reshape(eigen_dinput_shape);
      } else {
        blas.VCOPY(total_elems, dout->data<T>(), dinput->data<T>());
      }

      blas.SCAL(total_elems, beta, dinput->data<T>());
    }
    if (dx) {
      dx->mutable_data<T>(ctx.GetPlace());
      total_elems = x->dims()[0] * x->dims()[1];
      // dx = dout * y'. dx: M x K, dout : M x N, y : K x N
      blas.MatMul(*dout, false, *y, true, dx);
      blas.SCAL(total_elems, alpha, dx->data<T>());
    }
    if (dy) {
      dy->mutable_data<T>(ctx.GetPlace());
      total_elems = x->dims()[1] * y->dims()[1];
      // dy = x' * dout. dy: K x N, dout : M x N, x : M x K
      blas.MatMul(*x, true, *dout, false, dy);
      blas.SCAL(total_elems, alpha, dy->data<T>());
    }
  }
};
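
In the backward pass above, dInput sums dOut over any axis where Input was broadcast and scales by beta, while dX and dY are the standard matmul gradients scaled by alpha. A minimal NumPy sketch of the same math (an illustrative reference, not the kernel's API):

import numpy as np

def addmm_backward_reference(inp, x, y, dout, alpha=1.0, beta=1.0):
    dinp = dout
    if inp.shape[0] != dout.shape[0]:  # row_compress
        dinp = dinp.sum(axis=0, keepdims=True)
    if inp.shape[1] != dout.shape[1]:  # col_compress
        dinp = dinp.sum(axis=1, keepdims=True)
    dinp = beta * dinp.reshape(inp.shape)
    dx = alpha * (dout @ y.T)          # dx: M x K
    dy = alpha * (x.T @ dout)          # dy: K x N
    return dinp, dx, dy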

}  // namespace operators
}  // namespace paddle

@@ -0,0 +1,137 @@
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np
import paddle
import paddle.fluid.core as core
from op_test import OpTest
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard


class TestAddMMOp(OpTest):
    # test basic
    def setUp(self):
        self.op_type = "addmm"
        self.dtype = np.float64
        self.init_dtype_type()
        self.inputs = {
            'Input': np.random.random((100, 1)).astype(self.dtype),
            'X': np.random.random((100, 10)).astype(self.dtype),
            'Y': np.random.random((10, 20)).astype(self.dtype),
        }
        self.outputs = {
            'Out':
            self.inputs['Input'] + np.dot(self.inputs['X'], self.inputs['Y'])
        }

    def init_dtype_type(self):
        pass

    def test_check_output(self):
        self.check_output()

    def test_check_grad_normal(self):
        self.check_grad(['Input', 'X', 'Y'], 'Out')

    def test_check_grad_x(self):
        self.check_grad(['X'], 'Out', no_grad_set=None)

    def test_check_grad_y(self):
        self.check_grad(['Y'], 'Out', no_grad_set=None)

    def test_check_grad_input(self):
        self.check_grad(['Input'], 'Out', no_grad_set=None)


class TestAddMMOpError(unittest.TestCase):
    # test error
    def test_errors(self):
        with program_guard(Program(), Program()):
            # The input type of addmm_op must be Variable.
            input = fluid.create_lod_tensor(
                np.array([[-1]]), [[1]], fluid.CPUPlace())
            x1 = fluid.create_lod_tensor(
                np.array([[-1]]), [[1]], fluid.CPUPlace())
            x2 = fluid.create_lod_tensor(
                np.array([[-1]]), [[1]], fluid.CPUPlace())
            self.assertRaises(TypeError, paddle.addmm, input, x1, x2)
            # The input dtype of addmm_op must be float32 or float64.
            input = fluid.layers.data(name='input', shape=[4], dtype="int32")
            x3 = fluid.layers.data(name='x3', shape=[4], dtype="int32")
            x4 = fluid.layers.data(name='x4', shape=[4], dtype="int32")
            self.assertRaises(TypeError, paddle.addmm, input, x3, x4)


class TestAddMMOp2(TestAddMMOp):
    # test alpha and beta
    def setUp(self):
        self.op_type = "addmm"
        self.dtype = np.float64
        self.init_dtype_type()
        self.inputs = {
            'Input': np.random.random((20, 30)).astype(self.dtype),
            'X': np.random.random((20, 6)).astype(self.dtype),
            'Y': np.random.random((6, 30)).astype(self.dtype),
        }
        self.attrs = {
            'Alpha': 0.1,
            'Beta': 1.0,
        }
        self.outputs = {'Out': self.attrs['Beta'] * self.inputs['Input'] + \
                        self.attrs['Alpha'] * np.dot(self.inputs['X'], self.inputs['Y'])}


class TestAddMMOp3(OpTest):
    # test broadcast
    def setUp(self):
        self.op_type = "addmm"
        self.dtype = np.float64
        self.init_dtype_type()
        self.inputs = {
            'Input': np.random.random((1, 100)).astype(self.dtype),
            'X': np.random.random((20, 10)).astype(self.dtype),
            'Y': np.random.random((10, 100)).astype(self.dtype),
        }
        self.attrs = {
            'Alpha': 0.5,
            'Beta': 2.0,
        }
        self.outputs = {'Out': self.attrs['Beta'] * self.inputs['Input'] + \
                        self.attrs['Alpha'] * np.dot(self.inputs['X'], self.inputs['Y'])}

    def init_dtype_type(self):
        pass

    def test_check_output(self):
        self.check_output()

    def test_check_grad_normal(self):
        self.check_grad(['Input', 'X', 'Y'], 'Out')

    def test_check_grad_x(self):
        self.check_grad(['X'], 'Out', no_grad_set=None)

    def test_check_grad_y(self):
        self.check_grad(['Y'], 'Out', no_grad_set=None)

    def test_check_grad_input(self):
        self.check_grad(['Input'], 'Out', no_grad_set=None)


if __name__ == "__main__":
    unittest.main()