parent 7b5e23c034
commit 1c08a2136e
paddle/fluid/operators/addmm_op.cc
@@ -0,0 +1,236 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/addmm_op.h"
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif

namespace paddle {
namespace operators {

using framework::OpKernelType;
using framework::Tensor;

class AddMMOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE_EQ(ctx->HasInput("Input"), true,
                      platform::errors::NotFound(
                          "Input(Input) of AddMMOp should not be null."));
    PADDLE_ENFORCE_EQ(
        ctx->HasInput("X"), true,
        platform::errors::NotFound("Input(X) of AddMMOp should not be null."));
    PADDLE_ENFORCE_EQ(
        ctx->HasInput("Y"), true,
        platform::errors::NotFound("Input(Y) of AddMMOp should not be null."));
    PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
                      platform::errors::NotFound(
                          "Output(Out) of AddMMOp should not be null."));

    auto input_dims = ctx->GetInputDim("Input");
    auto x_dims = ctx->GetInputDim("X");
    auto y_dims = ctx->GetInputDim("Y");

    auto ndim_input = input_dims.size();
    auto ndim_x = x_dims.size();
    auto ndim_y = y_dims.size();

    float alpha = ctx->Attrs().Get<float>("Alpha");
    float beta = ctx->Attrs().Get<float>("Beta");

    VLOG(3) << "addmm operator input.shape=" << input_dims
            << " x.shape=" << x_dims << " y.shape=" << y_dims
            << " beta=" << beta << " alpha=" << alpha
            << " ndim_input=" << ndim_input << " ndim_x=" << ndim_x
            << " ndim_y=" << ndim_y;

    PADDLE_ENFORCE_NE(framework::product(input_dims), 0,
                      platform::errors::PreconditionNotMet(
                          "The Input variable Input(%s) has not "
                          "been initialized. You may need to confirm "
                          "whether exe.run(startup_program) is called "
                          "after optimizer.minimize.",
                          ctx->Inputs("Input").front()));

    PADDLE_ENFORCE_NE(framework::product(x_dims), 0,
                      platform::errors::PreconditionNotMet(
                          "The Input variable X(%s) has not "
                          "been initialized. You may need to confirm "
                          "whether exe.run(startup_program) is called "
                          "after optimizer.minimize.",
                          ctx->Inputs("X").front()));

    PADDLE_ENFORCE_NE(framework::product(y_dims), 0,
                      platform::errors::PreconditionNotMet(
                          "The Input variable Y(%s) has not "
                          "been initialized. You may need to confirm "
                          "whether exe.run(startup_program) is called "
                          "after optimizer.minimize.",
                          ctx->Inputs("Y").front()));
    // dimension checks
    PADDLE_ENFORCE_EQ(ndim_input, 2,
                      platform::errors::InvalidArgument(
                          "The input tensor input's dimension must be 2. "
                          "But received input's dimension = [%s].",
                          ndim_input));
    PADDLE_ENFORCE_EQ(ndim_x, 2,
                      platform::errors::InvalidArgument(
                          "The input tensor x's dimension must be 2. "
                          "But received x's dimension = [%s].",
                          ndim_x));
    PADDLE_ENFORCE_EQ(ndim_y, 2,
                      platform::errors::InvalidArgument(
                          "The input tensor y's dimension must be 2. "
                          "But received y's dimension = [%s].",
                          ndim_y));

    std::vector<int64_t> output_dims;
    output_dims.push_back(x_dims[0]);
    output_dims.push_back(y_dims[1]);

    ctx->SetOutputDim("Out", framework::make_ddim(output_dims));
    ctx->ShareLoD("Input", /*->*/ "Out");
  }

  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    framework::LibraryType library = framework::LibraryType::kPlain;
    framework::DataLayout layout = framework::DataLayout::kAnyLayout;
    int customized_type_value =
        framework::OpKernelType::kDefaultCustomizedTypeValue;
    auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
    if (library == framework::LibraryType::kPlain &&
        platform::CanMKLDNNBeUsed(ctx)) {
      library = framework::LibraryType::kMKLDNN;
      layout = framework::DataLayout::kMKLDNN;

      if (input_data_type == framework::DataTypeTrait<int8_t>::DataType() ||
          input_data_type == framework::DataTypeTrait<uint8_t>::DataType()) {
        customized_type_value = kMULMKLDNNINT8;
      }
    }
#endif

    return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout,
                                   library, customized_type_value);
  }
};

class AddMMOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("Input", "(Tensor), tensor to be added to the final result.");
    AddInput("X", "(Tensor), the first input tensor for mul.");
    AddInput("Y", "(Tensor), the second input tensor for mul.");
    AddOutput("Out", "(Tensor), the output tensor of the addmm op.");
    AddAttr<bool>("use_mkldnn",
                  "(bool, default false) Only used in the MKL-DNN kernel.")
        .SetDefault(false);
    AddAttr<float>("Alpha", "coefficient of x*y.").SetDefault(1.0f);
    AddAttr<float>("Beta", "coefficient of input.").SetDefault(1.0f);
    AddComment(R"DOC(
AddMM Operator.

This operator performs the matrix multiplication of $x$ and $y$ scaled by
coefficient $alpha$, then adds $input$ scaled by coefficient $beta$ to the
product. The equation is:

$$Out = alpha * x * y + beta * input$$

$x$ and $y$ must be two-dimensional matrices, and $input$ must be
broadcastable to the shape of $x * y$.
)DOC");
  }
};

class AddMMGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE_EQ(
        ctx->HasInput("Input"), true,
        platform::errors::NotFound("Input(Input) should not be null"));
    PADDLE_ENFORCE_EQ(
        ctx->HasInput("X"), true,
        platform::errors::NotFound("Input(X) should not be null"));
    PADDLE_ENFORCE_EQ(
        ctx->HasInput("Y"), true,
        platform::errors::NotFound("Input(Y) should not be null"));
    PADDLE_ENFORCE_EQ(
        ctx->HasInput(framework::GradVarName("Out")), true,
        platform::errors::NotFound("Input(Out@GRAD) should not be null"));
    const auto& input_dims = ctx->GetInputDim("Input");
    const auto& x_dims = ctx->GetInputDim("X");
    const auto& y_dims = ctx->GetInputDim("Y");

    auto input_grad_name = framework::GradVarName("Input");
    auto x_grad_name = framework::GradVarName("X");
    auto y_grad_name = framework::GradVarName("Y");

    if (ctx->HasOutput(input_grad_name)) {
      ctx->SetOutputDim(input_grad_name, input_dims);
    }
    if (ctx->HasOutput(x_grad_name)) {
      ctx->SetOutputDim(x_grad_name, x_dims);
    }
    if (ctx->HasOutput(y_grad_name)) {
      ctx->SetOutputDim(y_grad_name, y_dims);
    }
  }
};

template <typename T>
class AddMMOpGradMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> retv) const override {
    retv->SetType("addmm_grad");
    retv->SetInput("Input", this->Input("Input"));
    retv->SetInput("X", this->Input("X"));
    retv->SetInput("Y", this->Input("Y"));
    retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    retv->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input"));
    retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
    retv->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
    retv->SetAttrMap(this->Attrs());
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

REGISTER_OPERATOR(addmm, ops::AddMMOp, ops::AddMMOpMaker,
                  ops::AddMMOpGradMaker<paddle::framework::OpDesc>,
                  ops::AddMMOpGradMaker<paddle::imperative::OpBase>);

REGISTER_OPERATOR(addmm_grad, ops::AddMMGradOp);

REGISTER_OP_CPU_KERNEL(
    addmm, ops::AddMMKernel<paddle::platform::CPUDeviceContext, float>,
    ops::AddMMKernel<paddle::platform::CPUDeviceContext, double>);

REGISTER_OP_CPU_KERNEL(
    addmm_grad, ops::AddMMGradKernel<paddle::platform::CPUDeviceContext, float>,
    ops::AddMMGradKernel<paddle::platform::CPUDeviceContext, double>);
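Note: as a quick cross-check of the semantics registered above (including the
row/column broadcast of Input that the kernel implements with Eigen), here is a
minimal NumPy sketch of the forward computation. The helper name addmm_ref is
ours, not part of the Paddle codebase.

import numpy as np

def addmm_ref(input, x, y, alpha=1.0, beta=1.0):
    """Reference for addmm: Out = alpha * x @ y + beta * input.

    x: (M, K), y: (K, N); input must broadcast to (M, N), i.e. each of
    its two dims is either the corresponding output dim or 1.
    """
    m, k = x.shape
    k2, n = y.shape
    assert k == k2, "x's width must equal y's height"
    # np.broadcast_to mirrors the eigen_input.broadcast(bcast_dims) step.
    bcast = np.broadcast_to(input, (m, n))
    return alpha * (x @ y) + beta * bcast

out = addmm_ref(np.ones((1, 4)), np.ones((3, 2)), np.ones((2, 4)),
                alpha=0.5, beta=2.0)
print(out.shape)  # (3, 4); every entry is 0.5 * 2 + 2.0 * 1 = 3.0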
paddle/fluid/operators/addmm_op.cu
@@ -0,0 +1,24 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/addmm_op.h"

namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_CUDA_KERNEL(addmm, ops::AddMMKernel<plat::CUDADeviceContext, float>,
                        ops::AddMMKernel<plat::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(addmm_grad,
                        ops::AddMMGradKernel<plat::CUDADeviceContext, float>,
                        ops::AddMMGradKernel<plat::CUDADeviceContext, double>);
paddle/fluid/operators/addmm_op.h
@@ -0,0 +1,193 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <boost/preprocessor/repetition/repeat.hpp>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h"

namespace ops = paddle::operators;
namespace plat = paddle::platform;

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
template <typename T, size_t D, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;

using Array1 = Eigen::DSizes<int64_t, 1>;
using Array2 = Eigen::DSizes<int64_t, 2>;

constexpr int kMULMKLDNNINT8 = 1;

template <typename DeviceContext, typename T>
class AddMMKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    const Tensor* input = context.Input<Tensor>("Input");
    const Tensor* x = context.Input<Tensor>("X");
    const Tensor* y = context.Input<Tensor>("Y");

    auto input_dims = input->dims();
    auto x_dims = x->dims();
    auto y_dims = y->dims();

    // broadcast mode check on Input's rows
    if (x_dims[0] != input_dims[0]) {
      PADDLE_ENFORCE_EQ(input_dims[0], 1,
                        platform::errors::InvalidArgument(
                            "When x_dims[0] is not equal to input_dims[0], "
                            "input_dims[0] must be 1 but got %s",
                            input_dims[0]));
      PADDLE_ENFORCE_EQ(
          y_dims[1] == input_dims[1] || input_dims[1] == 1, true,
          platform::errors::InvalidArgument(
              "The input tensor shape mismatch, input shape=[%s], "
              "x shape=[%s], y shape=[%s]",
              input_dims, x_dims, y_dims));
    }
    // broadcast mode check on Input's columns
    if (y_dims[1] != input_dims[1]) {
      PADDLE_ENFORCE_EQ(input_dims[1], 1,
                        platform::errors::InvalidArgument(
                            "When y_dims[1] is not equal to input_dims[1], "
                            "input_dims[1] must be 1 but got %s",
                            input_dims[1]));
      PADDLE_ENFORCE_EQ(
          x_dims[0] == input_dims[0] || input_dims[0] == 1, true,
          platform::errors::InvalidArgument(
              "The input tensor shape mismatch, input shape=[%s], "
              "x shape=[%s], y shape=[%s]",
              input_dims, x_dims, y_dims));
    }
    // matmul dimension check
    PADDLE_ENFORCE_EQ(
        x_dims[1], y_dims[0],
        platform::errors::InvalidArgument(
            "The input tensor X's width must be equal to matrix Y's height. "
            "But received X's shape = [%s], Y's shape = [%s].",
            x_dims[1], y_dims[0]));

    auto* out = context.Output<Tensor>("Out");
    out->mutable_data<T>({x_dims[0], y_dims[1]}, context.GetPlace());

    float alpha = context.template Attr<float>("Alpha");
    float beta = context.template Attr<float>("Beta");

    auto blas = math::GetBlas<DeviceContext, T>(context);

    // calc broadcast dims
    Array2 bcast_dims;
    bcast_dims[0] = x_dims[0] / input_dims[0];
    bcast_dims[1] = y_dims[1] / input_dims[1];
    VLOG(3) << "bcast_dims=[" << bcast_dims[0] << "," << bcast_dims[1] << "]";
    // broadcast Input into Out using Eigen
    auto eigen_input = EigenTensor<T, 2>::From(*input);
    auto eigen_out = EigenTensor<T, 2>::From(*out);
    auto& place =
        *context.template device_context<DeviceContext>().eigen_device();
    eigen_out.device(place) = eigen_input.broadcast(bcast_dims);

    // Out = alpha * X * Y + beta * Out
    blas.GEMM(false, false, x_dims[0], y_dims[1], x_dims[1], alpha,
              x->data<T>(), x_dims[1], y->data<T>(), y_dims[1], beta,
              out->data<T>(), y_dims[1]);
  }
};

template <typename DeviceContext, typename T>
class AddMMGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* x = ctx.Input<framework::LoDTensor>("X");
    auto* y = ctx.Input<framework::LoDTensor>("Y");
    auto* dout = ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"));
    auto in_dims = ctx.Input<framework::LoDTensor>("Input")->dims();
    auto* dinput =
        ctx.Output<framework::LoDTensor>(framework::GradVarName("Input"));
    auto* dx = ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
    auto* dy = ctx.Output<framework::LoDTensor>(framework::GradVarName("Y"));

    float alpha = ctx.Attr<float>("Alpha");
    float beta = ctx.Attr<float>("Beta");

    int total_elems = 0;

    VLOG(3) << "alpha: " << alpha << " beta: " << beta;

    if (dinput != nullptr) {
      dinput->set_lod(dout->lod());
    }
    if (dx != nullptr) {
      dx->set_lod(x->lod());
    }
    if (dy != nullptr) {
      dy->set_lod(y->lod());
    }

    auto& dev_ctx = ctx.template device_context<DeviceContext>();
    auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
    if (dinput) {
      dinput->mutable_data<T>(ctx.GetPlace());
      total_elems = in_dims[0] * in_dims[1];
      auto& place =
          *ctx.template device_context<DeviceContext>().eigen_device();
      auto eigen_dout = EigenTensor<T, 2>::From(*dout);
      auto eigen_dinput = EigenTensor<T, 2>::From(*dinput);

      bool row_compress = in_dims[0] != dout->dims()[0];
      bool col_compress = in_dims[1] != dout->dims()[1];
      auto eigen_dinput_shape = Array2(dinput->dims()[0], dinput->dims()[1]);

      // dInput = beta * dOut, summed over any axis Input was broadcast along
      if (row_compress && col_compress) {
        eigen_dinput.device(place) =
            eigen_dout.sum().eval().reshape(eigen_dinput_shape);
      } else if (row_compress) {
        eigen_dinput.device(place) =
            eigen_dout.sum(Array1(0)).eval().reshape(eigen_dinput_shape);
      } else if (col_compress) {
        eigen_dinput.device(place) =
            eigen_dout.sum(Array1(1)).eval().reshape(eigen_dinput_shape);
      } else {
        blas.VCOPY(total_elems, dout->data<T>(), dinput->data<T>());
      }

      blas.SCAL(total_elems, beta, dinput->data<T>());
    }
    if (dx) {
      dx->mutable_data<T>(ctx.GetPlace());
      total_elems = x->dims()[0] * x->dims()[1];
      // dX = dOut * Y^T. dX: M x K, dOut: M x N, Y: K x N
      blas.MatMul(*dout, false, *y, true, dx);
      blas.SCAL(total_elems, alpha, dx->data<T>());
    }
    if (dy) {
      dy->mutable_data<T>(ctx.GetPlace());
      total_elems = x->dims()[1] * y->dims()[1];
      // dY = X^T * dOut. dY: K x N, dOut: M x N, X: M x K
      blas.MatMul(*x, true, *dout, false, dy);
      blas.SCAL(total_elems, alpha, dy->data<T>());
    }
  }
};

}  // namespace operators
}  // namespace paddle
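Note: the gradient kernel above reduces to three closed-form expressions. The
NumPy sketch below restates them for a quick sanity check; the helper name
addmm_grad_ref is ours, not part of the Paddle codebase.

import numpy as np

def addmm_grad_ref(input_shape, x, y, dout, alpha=1.0, beta=1.0):
    """Mirrors AddMMGradKernel: dInput = beta * dOut summed over broadcast
    axes of Input; dX = alpha * dOut @ Y^T; dY = alpha * X^T @ dOut."""
    dinput = beta * dout
    # Sum over any axis where Input had size 1 (the row_compress /
    # col_compress branches in the kernel).
    for axis, dim in enumerate(input_shape):
        if dim == 1:
            dinput = dinput.sum(axis=axis, keepdims=True)
    dx = alpha * dout @ y.T  # (M, K)
    dy = alpha * x.T @ dout  # (K, N)
    return dinput, dx, dy

x, y = np.ones((3, 2)), np.ones((2, 4))
dinput, dx, dy = addmm_grad_ref((1, 4), x, y, np.ones((3, 4)), beta=2.0)
print(dinput.shape, dx.shape, dy.shape)  # (1, 4) (3, 2) (2, 4)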
python/paddle/fluid/tests/unittests/test_addmm_op.py
@@ -0,0 +1,137 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np
import paddle
import paddle.fluid.core as core
from op_test import OpTest
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard


class TestAddMMOp(OpTest):
    # test basic
    def setUp(self):
        self.op_type = "addmm"
        self.dtype = np.float64
        self.init_dtype_type()
        self.inputs = {
            'Input': np.random.random((100, 1)).astype(self.dtype),
            'X': np.random.random((100, 10)).astype(self.dtype),
            'Y': np.random.random((10, 20)).astype(self.dtype),
        }
        self.outputs = {
            'Out':
            self.inputs['Input'] + np.dot(self.inputs['X'], self.inputs['Y'])
        }

    def init_dtype_type(self):
        pass

    def test_check_output(self):
        self.check_output()

    def test_check_grad_normal(self):
        self.check_grad(['Input', 'X', 'Y'], 'Out')

    def test_check_grad_x(self):
        self.check_grad(['X'], 'Out', no_grad_set=None)

    def test_check_grad_y(self):
        self.check_grad(['Y'], 'Out', no_grad_set=None)

    def test_check_grad_input(self):
        self.check_grad(['Input'], 'Out', no_grad_set=None)


class TestAddMMOpError(unittest.TestCase):
    # test error
    def test_errors(self):
        with program_guard(Program(), Program()):
            # The input type of addmm_op must be Variable.
            input = fluid.create_lod_tensor(
                np.array([[-1]]), [[1]], fluid.CPUPlace())
            x1 = fluid.create_lod_tensor(
                np.array([[-1]]), [[1]], fluid.CPUPlace())
            x2 = fluid.create_lod_tensor(
                np.array([[-1]]), [[1]], fluid.CPUPlace())
            self.assertRaises(TypeError, paddle.addmm, input, x1, x2)
            # The input dtype of addmm_op must be float32 or float64.
            input = fluid.layers.data(name='input', shape=[4], dtype="int32")
            x3 = fluid.layers.data(name='x3', shape=[4], dtype="int32")
            x4 = fluid.layers.data(name='x4', shape=[4], dtype="int32")
            self.assertRaises(TypeError, paddle.addmm, input, x3, x4)


class TestAddMMOp2(TestAddMMOp):
    # test alpha and beta
    def setUp(self):
        self.op_type = "addmm"
        self.dtype = np.float64
        self.init_dtype_type()
        self.inputs = {
            'Input': np.random.random((20, 30)).astype(self.dtype),
            'X': np.random.random((20, 6)).astype(self.dtype),
            'Y': np.random.random((6, 30)).astype(self.dtype),
        }
        self.attrs = {
            'Alpha': 0.1,
            'Beta': 1.0,
        }
        self.outputs = {
            'Out': self.attrs['Beta'] * self.inputs['Input'] +
            self.attrs['Alpha'] * np.dot(self.inputs['X'], self.inputs['Y'])
        }


class TestAddMMOp3(OpTest):
    # test broadcast
    def setUp(self):
        self.op_type = "addmm"
        self.dtype = np.float64
        self.init_dtype_type()
        self.inputs = {
            'Input': np.random.random((1, 100)).astype(self.dtype),
            'X': np.random.random((20, 10)).astype(self.dtype),
            'Y': np.random.random((10, 100)).astype(self.dtype),
        }
        self.attrs = {
            'Alpha': 0.5,
            'Beta': 2.0,
        }
        self.outputs = {
            'Out': self.attrs['Beta'] * self.inputs['Input'] +
            self.attrs['Alpha'] * np.dot(self.inputs['X'], self.inputs['Y'])
        }

    def init_dtype_type(self):
        pass

    def test_check_output(self):
        self.check_output()

    def test_check_grad_normal(self):
        self.check_grad(['Input', 'X', 'Y'], 'Out')

    def test_check_grad_x(self):
        self.check_grad(['X'], 'Out', no_grad_set=None)

    def test_check_grad_y(self):
        self.check_grad(['Y'], 'Out', no_grad_set=None)

    def test_check_grad_input(self):
        self.check_grad(['Input'], 'Out', no_grad_set=None)


if __name__ == "__main__":
    unittest.main()
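Note: for completeness, the Python API exercised by TestAddMMOpError can be
driven end to end as sketched below. The alpha/beta keyword names mirror the
C++ attributes, but their exact names and order in paddle.addmm should be
checked against the Python API added alongside this operator; treat this as a
sketch under that assumption, not a canonical example.

import numpy as np
import paddle
import paddle.fluid as fluid

with fluid.program_guard(fluid.Program(), fluid.Program()):
    input = fluid.data(name='input', shape=[2, 2], dtype='float32')
    x = fluid.data(name='x', shape=[2, 2], dtype='float32')
    y = fluid.data(name='y', shape=[2, 2], dtype='float32')
    # Out = alpha * x @ y + beta * input
    out = paddle.addmm(input, x, y, alpha=5.0, beta=0.5)

    exe = fluid.Executor(fluid.CPUPlace())
    feed = {k: np.ones((2, 2), dtype='float32') for k in ('input', 'x', 'y')}
    res, = exe.run(feed=feed, fetch_list=[out])
    print(res)  # 5.0 * 2 + 0.5 * 1 = 10.5 everywhere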