add new dot op(#23418)
parent
cdbe5707e9
commit
2fd728a978
@ -0,0 +1,160 @@
|
||||
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/operators/dot_op.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
class DotOp : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
|
||||
void InferShape(framework::InferShapeContext* ctx) const override {
|
||||
PADDLE_ENFORCE_EQ(true, ctx->HasInput("X"),
|
||||
platform::errors::PreconditionNotMet(
|
||||
"Input(X) of DotOp should not be null."));
|
||||
PADDLE_ENFORCE_EQ(true, ctx->HasInput("Y"),
|
||||
platform::errors::PreconditionNotMet(
|
||||
"Input(Y) of DotOp should not be null."));
|
||||
PADDLE_ENFORCE_EQ(true, ctx->HasOutput("Out"),
|
||||
platform::errors::PreconditionNotMet(
|
||||
"Output(Out) of DotOp should not be null."));
|
||||
|
||||
auto x_dims = ctx->GetInputDim("X");
|
||||
auto x_rank = (size_t)x_dims.size();
|
||||
PADDLE_ENFORCE_EQ(true, 1 == x_rank || 2 == x_rank,
|
||||
platform::errors::PreconditionNotMet(
|
||||
"ShapeError: The dimensions of input tensor X (%s) "
|
||||
"should be 1 or 2",
|
||||
x_dims.to_str()));
|
||||
|
||||
auto y_dims = ctx->GetInputDim("Y");
|
||||
PADDLE_ENFORCE_EQ(
|
||||
true, x_rank == (size_t)y_dims.size(),
|
||||
platform::errors::PreconditionNotMet(
|
||||
"ShapeError: The shape of input tensor Y: %s should match with "
|
||||
"input tenosr X: %s",
|
||||
y_dims.to_str(), x_dims.to_str()));
|
||||
bool shape_match = true;
|
||||
for (size_t i = 0; i < x_rank; ++i) {
|
||||
if (x_dims[i] != y_dims[i]) {
|
||||
shape_match = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
PADDLE_ENFORCE_EQ(true, shape_match,
|
||||
platform::errors::PreconditionNotMet(
|
||||
"ShapeError: The shape of input tensor X: %s should "
|
||||
"be exactly the same "
|
||||
"with input tensor Y: %s",
|
||||
x_dims.to_str(), y_dims.to_str()));
|
||||
auto dims = vectorize(x_dims);
|
||||
dims[dims.size() - 1] = 1;
|
||||
ctx->SetOutputDim("Out", framework::make_ddim(dims));
|
||||
}
|
||||
|
||||
framework::OpKernelType GetExpectedKernelType(
|
||||
const framework::ExecutionContext& ctx) const override {
|
||||
return framework::OpKernelType(
|
||||
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
|
||||
}
|
||||
};
|
||||
|
||||
class DotOpMaker : public framework::OpProtoAndCheckerMaker {
|
||||
public:
|
||||
void Make() final {
|
||||
AddInput("X", "(Tensor) The first input tensor. ");
|
||||
AddInput("Y", "(Tensor) The second input tensor. ");
|
||||
AddOutput("Out", "(Tensor) The result tensor.");
|
||||
AddComment("");
|
||||
}
|
||||
};
|
||||
|
||||
class DotGradOp : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
|
||||
void InferShape(framework::InferShapeContext* ctx) const override {
|
||||
PADDLE_ENFORCE_EQ(
|
||||
true, ctx->HasInput("X"),
|
||||
platform::errors::PreconditionNotMet("Input(X) should not be null."));
|
||||
PADDLE_ENFORCE_EQ(
|
||||
true, ctx->HasInput("Y"),
|
||||
platform::errors::PreconditionNotMet("Input(Y) should not be null."));
|
||||
PADDLE_ENFORCE_EQ(true, ctx->HasInput(framework::GradVarName("Out")),
|
||||
platform::errors::PreconditionNotMet(
|
||||
"Input(Out@GRAD) should not be null."));
|
||||
|
||||
auto x_grad_name = framework::GradVarName("X");
|
||||
auto y_grad_name = framework::GradVarName("Y");
|
||||
if (ctx->HasOutput(x_grad_name)) {
|
||||
ctx->ShareDim("X", /*->*/ x_grad_name);
|
||||
ctx->ShareLoD("X", /*->*/ x_grad_name);
|
||||
}
|
||||
if (ctx->HasOutput(y_grad_name)) {
|
||||
ctx->ShareDim("Y", /*->*/ y_grad_name);
|
||||
ctx->ShareLoD("Y", /*->*/ y_grad_name);
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
framework::OpKernelType GetExpectedKernelType(
|
||||
const framework::ExecutionContext& ctx) const override {
|
||||
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
|
||||
ctx, framework::GradVarName("Out")),
|
||||
ctx.GetPlace());
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class DotOpGradMaker : public framework::SingleGradOpMaker<T> {
|
||||
public:
|
||||
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
|
||||
|
||||
protected:
|
||||
void Apply(GradOpPtr<T> op) const override {
|
||||
op->SetType("dot_grad");
|
||||
|
||||
op->SetInput("X", this->Input("X"));
|
||||
op->SetInput("Y", this->Input("Y"));
|
||||
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
|
||||
op->SetAttrMap(this->Attrs());
|
||||
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
|
||||
op->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
|
||||
REGISTER_OPERATOR(dot, ops::DotOp, ops::DotOpMaker,
|
||||
ops::DotOpGradMaker<paddle::framework::OpDesc>,
|
||||
ops::DotOpGradMaker<paddle::imperative::OpBase>);
|
||||
|
||||
REGISTER_OPERATOR(dot_grad, ops::DotGradOp);
|
||||
|
||||
REGISTER_OP_CPU_KERNEL(
|
||||
dot, ops::DotKernel<paddle::platform::CPUDeviceContext, float>,
|
||||
ops::DotKernel<paddle::platform::CPUDeviceContext, double>,
|
||||
ops::DotKernel<paddle::platform::CPUDeviceContext, int>,
|
||||
ops::DotKernel<paddle::platform::CPUDeviceContext, int64_t>);
|
||||
REGISTER_OP_CPU_KERNEL(
|
||||
dot_grad, ops::DotGradKernel<paddle::platform::CPUDeviceContext, float>,
|
||||
ops::DotGradKernel<paddle::platform::CPUDeviceContext, double>,
|
||||
ops::DotGradKernel<paddle::platform::CPUDeviceContext, int>,
|
||||
ops::DotGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
|
@ -0,0 +1,28 @@
|
||||
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/operators/dot_op.h"
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
namespace plat = paddle::platform;
|
||||
|
||||
REGISTER_OP_CUDA_KERNEL(dot, ops::DotKernel<plat::CUDADeviceContext, float>,
|
||||
ops::DotKernel<plat::CUDADeviceContext, double>,
|
||||
ops::DotKernel<plat::CUDADeviceContext, int>,
|
||||
ops::DotKernel<plat::CUDADeviceContext, int64_t>);
|
||||
REGISTER_OP_CUDA_KERNEL(dot_grad,
|
||||
ops::DotGradKernel<plat::CUDADeviceContext, float>,
|
||||
ops::DotGradKernel<plat::CUDADeviceContext, double>,
|
||||
ops::DotGradKernel<plat::CUDADeviceContext, int>,
|
||||
ops::DotGradKernel<plat::CUDADeviceContext, int64_t>);
|
@ -0,0 +1,168 @@
|
||||
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
#include "paddle/fluid/framework/operator.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
using Tensor = framework::Tensor;
|
||||
|
||||
template <typename T, int MajorType = Eigen::RowMajor,
|
||||
typename IndexType = Eigen::DenseIndex>
|
||||
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
|
||||
|
||||
template <typename DeviceContext, typename T>
|
||||
class DotKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext& ctx) const override {
|
||||
auto* tensor_x = ctx.Input<Tensor>("X");
|
||||
auto* tensor_y = ctx.Input<Tensor>("Y");
|
||||
auto* tensor_out = ctx.Output<Tensor>("Out");
|
||||
tensor_out->mutable_data<T>(ctx.GetPlace());
|
||||
|
||||
#ifdef __NVCC__
|
||||
if (1 == tensor_out->dims().size()) {
|
||||
auto out = framework::EigenScalar<T>::From(*tensor_out);
|
||||
auto x = framework::EigenVector<T>::Flatten(*tensor_x);
|
||||
auto y = framework::EigenVector<T>::Flatten(*tensor_y);
|
||||
|
||||
auto& dev = *ctx.template device_context<DeviceContext>().eigen_device();
|
||||
out.device(dev) = (x * y).sum();
|
||||
} else {
|
||||
auto out = EigenMatrix<T>::From(*tensor_out);
|
||||
auto x = EigenMatrix<T>::From(*tensor_x);
|
||||
auto y = EigenMatrix<T>::From(*tensor_y);
|
||||
|
||||
auto& dev = *ctx.template device_context<DeviceContext>().eigen_device();
|
||||
out.device(dev) = (x * y).sum(Eigen::DSizes<int, 1>(1));
|
||||
}
|
||||
#else
|
||||
const auto* data_x = tensor_x->data<T>();
|
||||
const auto* data_y = tensor_y->data<T>();
|
||||
auto* data_out = tensor_out->data<T>();
|
||||
|
||||
auto x_dims = tensor_x->dims();
|
||||
auto step = x_dims[x_dims.size() - 1];
|
||||
int size = static_cast<int>(framework::product(x_dims));
|
||||
|
||||
for (int ind = -1, j = 0; j < size; ++j) {
|
||||
if (j % step == 0) {
|
||||
++ind;
|
||||
data_out[ind] = data_x[j] * data_y[j];
|
||||
} else {
|
||||
data_out[ind] += data_x[j] * data_y[j];
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template <typename DeviceContext, typename T>
|
||||
class DotGradKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext& ctx) const override {
|
||||
auto* tensor_x = ctx.Input<Tensor>("X");
|
||||
auto* tensor_y = ctx.Input<Tensor>("Y");
|
||||
auto* tensor_dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
|
||||
auto* tensor_dx = ctx.Output<Tensor>(framework::GradVarName("X"));
|
||||
auto* tensor_dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
|
||||
|
||||
if (tensor_dx) tensor_dx->mutable_data<T>(ctx.GetPlace());
|
||||
if (tensor_dy) tensor_dy->mutable_data<T>(ctx.GetPlace());
|
||||
#ifdef __NVCC__
|
||||
if (1 == tensor_dout->dims().size()) {
|
||||
auto dout = framework::EigenVector<T>::Flatten(*tensor_dout);
|
||||
|
||||
if (tensor_dx) {
|
||||
auto y = framework::EigenVector<T>::Flatten(*tensor_y);
|
||||
auto dx = framework::EigenVector<T>::Flatten(*tensor_dx);
|
||||
auto& dev =
|
||||
*ctx.template device_context<DeviceContext>().eigen_device();
|
||||
Eigen::DSizes<int, 1> size(tensor_dx->numel());
|
||||
dx.device(dev) = y * dout.broadcast(size);
|
||||
}
|
||||
|
||||
if (tensor_dy) {
|
||||
auto x = framework::EigenVector<T>::Flatten(*tensor_x);
|
||||
auto dy = framework::EigenVector<T>::Flatten(*tensor_dy);
|
||||
auto& dev =
|
||||
*ctx.template device_context<DeviceContext>().eigen_device();
|
||||
Eigen::DSizes<int, 1> size(tensor_dy->numel());
|
||||
dy.device(dev) = x * dout.broadcast(size);
|
||||
}
|
||||
} else {
|
||||
auto dout = EigenMatrix<T>::From(*tensor_dout);
|
||||
|
||||
if (tensor_dx) {
|
||||
tensor_dx->mutable_data<T>(ctx.GetPlace());
|
||||
auto y = EigenMatrix<T>::From(*tensor_y);
|
||||
auto dx = EigenMatrix<T>::From(*tensor_dx);
|
||||
auto& dev =
|
||||
*ctx.template device_context<DeviceContext>().eigen_device();
|
||||
Eigen::DSizes<int, 2> size(1, tensor_dx->dims()[1]);
|
||||
dx.device(dev) = y * dout.broadcast(size);
|
||||
}
|
||||
|
||||
if (tensor_dy) {
|
||||
tensor_dy->mutable_data<T>(ctx.GetPlace());
|
||||
auto x = EigenMatrix<T>::From(*tensor_x);
|
||||
auto dy = EigenMatrix<T>::From(*tensor_dy);
|
||||
auto& dev =
|
||||
*ctx.template device_context<DeviceContext>().eigen_device();
|
||||
Eigen::DSizes<int, 2> size(1, tensor_dy->dims()[1]);
|
||||
dy.device(dev) = x * dout.broadcast(size);
|
||||
}
|
||||
}
|
||||
#else
|
||||
const auto* data_dout = tensor_dout->data<T>();
|
||||
|
||||
if (tensor_dx) {
|
||||
auto* data_dx = tensor_dx->mutable_data<T>(ctx.GetPlace());
|
||||
const auto* data_y = tensor_y->data<T>();
|
||||
const framework::DDim& dim = tensor_x->dims();
|
||||
size_t N = static_cast<size_t>(framework::product(dim));
|
||||
|
||||
auto step = dim[dim.size() - 1];
|
||||
|
||||
int s = -1;
|
||||
for (size_t i = 0; i < N; ++i) {
|
||||
if (0 == i % step) ++s;
|
||||
data_dx[i] = data_y[i] * data_dout[s];
|
||||
}
|
||||
}
|
||||
|
||||
if (tensor_dy) {
|
||||
auto* data_dy = tensor_dy->mutable_data<T>(ctx.GetPlace());
|
||||
const auto* data_x = tensor_x->data<T>();
|
||||
const framework::DDim& dim = tensor_y->dims();
|
||||
size_t N = static_cast<size_t>(framework::product(dim));
|
||||
|
||||
auto step = dim[dim.size() - 1];
|
||||
|
||||
int s = -1;
|
||||
for (size_t i = 0; i < N; ++i) {
|
||||
if (0 == i % step) ++s;
|
||||
data_dy[i] = data_x[i] * data_dout[s];
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,105 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
import paddle
|
||||
import paddle.fluid as fluid
|
||||
import unittest
|
||||
import numpy as np
|
||||
from op_test import OpTest, skip_check_grad_ci
|
||||
from paddle.fluid.op import Operator
|
||||
from paddle.fluid import compiler, Program, program_guard
|
||||
|
||||
|
||||
class DotOp(OpTest):
|
||||
def setUp(self):
|
||||
self.op_type = "dot"
|
||||
self.init_dtype()
|
||||
self.init_input_output()
|
||||
|
||||
self.inputs = {
|
||||
'X': OpTest.np_dtype_to_fluid_dtype(self.x),
|
||||
'Y': OpTest.np_dtype_to_fluid_dtype(self.y)
|
||||
}
|
||||
self.outputs = {'Out': self.out}
|
||||
self.attrs = {}
|
||||
|
||||
def test_check_output(self):
|
||||
self.check_output()
|
||||
|
||||
def test_check_grad_normal(self):
|
||||
self.check_grad(['X', 'Y'], 'Out')
|
||||
|
||||
def test_check_grad_ingore_x(self):
|
||||
self.check_grad(['Y'], 'Out', no_grad_set=set("X"))
|
||||
|
||||
def test_check_grad_ingore_y(self):
|
||||
self.check_grad(['X'], 'Out', no_grad_set=set('Y'))
|
||||
|
||||
def init_input_output(self):
|
||||
self.x = np.random.uniform(0.1, 1, [121]).astype(self.dtype)
|
||||
self.y = np.random.uniform(1, 3, [121]).astype(self.dtype)
|
||||
self.out = np.dot(self.x, self.y)
|
||||
|
||||
def init_dtype(self):
|
||||
self.dtype = np.float64
|
||||
|
||||
|
||||
class DotOpBatch(DotOp):
|
||||
def init_input_output(self):
|
||||
self.x = np.random.uniform(0.1, 1, [132]).astype(self.dtype).reshape(
|
||||
[11, 12])
|
||||
self.y = np.random.uniform(1, 3, [132]).astype(self.dtype).reshape(
|
||||
[11, 12])
|
||||
self.out = np.sum(self.x * self.y, axis=1).reshape([11, 1])
|
||||
|
||||
|
||||
class TestDotOpError(unittest.TestCase):
|
||||
def test_errors(self):
|
||||
with program_guard(Program(), Program()):
|
||||
|
||||
# the input dtype of elementwise_mul must be float16 or float32 or float64 or int32 or int64
|
||||
# float16 only can be set on GPU place
|
||||
x1 = fluid.layers.data(name='x1', shape=[120], dtype="uint8")
|
||||
y1 = fluid.layers.data(name='y1', shape=[120], dtype="uint8")
|
||||
self.assertRaises(Exception, paddle.dot, x1, y1)
|
||||
|
||||
x2 = fluid.layers.data(name='x2', shape=[2, 3], dtype="float32")
|
||||
y2 = fluid.layers.data(name='y2', shape=[2, 3], dtype="float32")
|
||||
self.assertRaises(Exception, paddle.dot, x2, y2)
|
||||
|
||||
x3 = fluid.layers.data(name='x3', shape=[3], dtype="float32")
|
||||
y3 = fluid.layers.data(name='y3', shape=[2, 3], dtype="float32")
|
||||
self.assertRaises(Exception, paddle.dot, x2, y3)
|
||||
|
||||
|
||||
class TestDygraph(unittest.TestCase):
|
||||
def test_dygraph(self):
|
||||
with fluid.dygraph.guard():
|
||||
x1 = fluid.dygraph.to_variable(np.array([1, 3]).astype(np.float32))
|
||||
y1 = fluid.dygraph.to_variable(np.array([2, 5]).astype(np.float32))
|
||||
self.assertTrue(
|
||||
np.allclose(paddle.dot(x1, y1).numpy(), np.array([17])))
|
||||
|
||||
x1 = fluid.dygraph.to_variable(
|
||||
np.array([[1, 3], [3, 5]]).astype(np.float32))
|
||||
y1 = fluid.dygraph.to_variable(
|
||||
np.array([[2, 5], [6, 8]]).astype(np.float32))
|
||||
self.assertTrue(
|
||||
np.array_equal(
|
||||
paddle.dot(x1, y1).numpy(), np.array([[17], [58]])))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Loading…
Reference in new issue