add kron op (#24105)
* add kron op and its python API, doc and unittests. * add kron in paddle.complexrevert-22778-infer_var_type
parent
eb411613e9
commit
e01262e691
@ -0,0 +1,168 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "paddle/fluid/operators/kron_op.h"
|
||||
#include "paddle/fluid/platform/float16.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
class KronOp : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
|
||||
void InferShape(framework::InferShapeContext* ctx) const override {
|
||||
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "kron");
|
||||
OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "kron");
|
||||
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "kron");
|
||||
|
||||
auto dim_x = ctx->GetInputDim("X");
|
||||
auto dim_y = ctx->GetInputDim("Y");
|
||||
auto rank_x = dim_x.size();
|
||||
auto rank_y = dim_y.size();
|
||||
auto rank = (rank_x > rank_y) ? rank_x : rank_y;
|
||||
|
||||
std::vector<int64_t> dim_out;
|
||||
dim_out.reserve(rank);
|
||||
for (int i = 0; i < rank; i++) {
|
||||
int64_t dim_xi = (i < rank - rank_x) ? 1 : dim_x.at(i - (rank - rank_x));
|
||||
int64_t dim_yi = (i < rank - rank_y) ? 1 : dim_y.at(i - (rank - rank_y));
|
||||
dim_out.push_back(dim_xi == -1 || dim_yi == -1 ? -1 : dim_xi * dim_yi);
|
||||
}
|
||||
ctx->SetOutputDim("Out", framework::make_ddim(dim_out));
|
||||
}
|
||||
|
||||
protected:
|
||||
framework::OpKernelType GetExpectedKernelType(
|
||||
const framework::ExecutionContext& ctx) const override {
|
||||
return framework::OpKernelType(
|
||||
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
|
||||
}
|
||||
};
|
||||
|
||||
class KronOpMaker : public framework::OpProtoAndCheckerMaker {
|
||||
public:
|
||||
void Make() override {
|
||||
AddInput("X", "(Tensor), the first operand of kron op");
|
||||
AddInput("Y", "(Tensor), the second operand of kron op");
|
||||
AddOutput("Out", "(Tensor), the output of kron op.");
|
||||
AddComment(R"DOC(
|
||||
Kron Operator.
|
||||
|
||||
This operator computes the Kronecker product of two tensors, a
|
||||
composite tensor made of blocks of the second tensor scaled by the
|
||||
first.
|
||||
|
||||
This operator assumes that the rank of the two tensors, $X$ and $Y$
|
||||
are the same, if necessary prepending the smallest with ones. If the
|
||||
shape of $X$ is [$r_0$, $r_1$, ..., $r_N$] and the shape of $Y$ is
|
||||
[$s_0$, $s_1$, ..., $s_N$], then the shape of the output tensor is
|
||||
[$r_{0}s_{0}$, $r_{1}s_{1}$, ..., $r_{N}s_{N}$]. The elements are
|
||||
products of elements from $X$ and $Y$.
|
||||
|
||||
The equation is:
|
||||
$$
|
||||
output[k_{0}, k_{1}, ..., k_{N}] = X[i_{0}, i_{1}, ..., i_{N}] *
|
||||
Y[j_{0}, j_{1}, ..., j_{N}]
|
||||
$$
|
||||
|
||||
where
|
||||
$$
|
||||
k_{t} = i_{t} * s_{t} + j_{t}, t = 0, 1, ..., N
|
||||
$$
|
||||
)DOC");
|
||||
}
|
||||
};
|
||||
|
||||
class KronGradOp : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
|
||||
void InferShape(framework::InferShapeContext* ctx) const override {
|
||||
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "kron_grad");
|
||||
OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "kron_grad");
|
||||
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
|
||||
framework::GradVarName("Out"), "kron_grad");
|
||||
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output",
|
||||
framework::GradVarName("X"), "kron_grad");
|
||||
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("Y")), "Output",
|
||||
framework::GradVarName("Y"), "kron_grad");
|
||||
|
||||
auto x_grad_name = framework::GradVarName("X");
|
||||
auto y_grad_name = framework::GradVarName("Y");
|
||||
ctx->SetOutputDim(x_grad_name, ctx->GetInputDim("X"));
|
||||
ctx->ShareLoD("X", /*->*/ x_grad_name);
|
||||
ctx->SetOutputDim(y_grad_name, ctx->GetInputDim("Y"));
|
||||
ctx->ShareLoD("Y", /*->*/ y_grad_name);
|
||||
}
|
||||
|
||||
protected:
|
||||
framework::OpKernelType GetExpectedKernelType(
|
||||
const framework::ExecutionContext& ctx) const override {
|
||||
auto out_grad_name = framework::GradVarName("Out");
|
||||
return framework::OpKernelType(
|
||||
OperatorWithKernel::IndicateVarDataType(ctx, out_grad_name),
|
||||
ctx.GetPlace());
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class KronGradOpMaker : public framework::SingleGradOpMaker<T> {
|
||||
public:
|
||||
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
|
||||
|
||||
protected:
|
||||
void Apply(GradOpPtr<T> grad_op) const override {
|
||||
grad_op->SetType("kron_grad");
|
||||
|
||||
grad_op->SetInput("X", this->Input("X"));
|
||||
grad_op->SetInput("Y", this->Input("Y"));
|
||||
grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
|
||||
|
||||
grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
|
||||
grad_op->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
|
||||
|
||||
grad_op->SetAttrMap(this->Attrs());
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
|
||||
REGISTER_OPERATOR(kron, ops::KronOp, ops::KronOpMaker,
|
||||
ops::KronGradOpMaker<paddle::framework::OpDesc>,
|
||||
ops::KronGradOpMaker<paddle::imperative::OpBase>);
|
||||
REGISTER_OP_CPU_KERNEL(
|
||||
kron, ops::KronKernel<paddle::platform::CPUDeviceContext, float>,
|
||||
ops::KronKernel<paddle::platform::CPUDeviceContext, double>,
|
||||
ops::KronKernel<paddle::platform::CPUDeviceContext,
|
||||
paddle::platform::float16>,
|
||||
ops::KronKernel<paddle::platform::CPUDeviceContext, int>,
|
||||
ops::KronKernel<paddle::platform::CPUDeviceContext, int64_t>);
|
||||
|
||||
REGISTER_OPERATOR(kron_grad, ops::KronGradOp);
|
||||
REGISTER_OP_CPU_KERNEL(
|
||||
kron_grad, ops::KronGradKernel<paddle::platform::CPUDeviceContext, float>,
|
||||
ops::KronGradKernel<paddle::platform::CPUDeviceContext, double>,
|
||||
ops::KronGradKernel<paddle::platform::CPUDeviceContext,
|
||||
paddle::platform::float16>,
|
||||
ops::KronGradKernel<paddle::platform::CPUDeviceContext, int>,
|
||||
ops::KronGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
|
||||
@ -0,0 +1,33 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/operators/kron_op.h"
|
||||
#include "paddle/fluid/platform/float16.h"
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
REGISTER_OP_CUDA_KERNEL(
|
||||
kron, ops::KronKernel<paddle::platform::CUDADeviceContext, float>,
|
||||
ops::KronKernel<paddle::platform::CUDADeviceContext, double>,
|
||||
ops::KronKernel<paddle::platform::CUDADeviceContext,
|
||||
paddle::platform::float16>,
|
||||
ops::KronKernel<paddle::platform::CUDADeviceContext, int>,
|
||||
ops::KronKernel<paddle::platform::CUDADeviceContext, int64_t>);
|
||||
|
||||
REGISTER_OP_CUDA_KERNEL(
|
||||
kron_grad, ops::KronGradKernel<paddle::platform::CUDADeviceContext, float>,
|
||||
ops::KronGradKernel<paddle::platform::CUDADeviceContext, double>,
|
||||
ops::KronGradKernel<paddle::platform::CUDADeviceContext,
|
||||
paddle::platform::float16>,
|
||||
ops::KronGradKernel<paddle::platform::CUDADeviceContext, int>,
|
||||
ops::KronGradKernel<paddle::platform::CUDADeviceContext, int64_t>);
|
||||
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,70 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from paddle import fluid, tensor
|
||||
import paddle.complex as cpx
|
||||
import paddle.fluid.dygraph as dg
|
||||
import numpy as np
|
||||
import unittest
|
||||
|
||||
|
||||
class ComplexKronTestCase(unittest.TestCase):
|
||||
def __init__(self, methodName='runTest', x=None, y=None):
|
||||
super(ComplexKronTestCase, self).__init__(methodName)
|
||||
self.x = x
|
||||
self.y = y
|
||||
|
||||
def setUp(self):
|
||||
self.ref_result = np.kron(self.x, self.y)
|
||||
|
||||
def runTest(self):
|
||||
place = fluid.CPUPlace()
|
||||
self.test_identity(place)
|
||||
|
||||
if fluid.is_compiled_with_cuda():
|
||||
place = fluid.CUDAPlace(0)
|
||||
self.test_identity(place)
|
||||
|
||||
def test_identity(self, place):
|
||||
with dg.guard(place):
|
||||
x_var = dg.to_variable(self.x)
|
||||
y_var = dg.to_variable(self.y)
|
||||
out_var = cpx.kron(x_var, y_var)
|
||||
np.testing.assert_allclose(out_var.numpy(), self.ref_result)
|
||||
|
||||
|
||||
def load_tests(loader, standard_tests, pattern):
|
||||
suite = unittest.TestSuite()
|
||||
suite.addTest(
|
||||
ComplexKronTestCase(
|
||||
x=np.random.randn(2, 2) + 1j * np.random.randn(2, 2),
|
||||
y=np.random.randn(3, 3) + 1j * np.random.randn(3, 3)))
|
||||
suite.addTest(
|
||||
ComplexKronTestCase(
|
||||
x=np.random.randn(2, 2),
|
||||
y=np.random.randn(3, 3) + 1j * np.random.randn(3, 3)))
|
||||
suite.addTest(
|
||||
ComplexKronTestCase(
|
||||
x=np.random.randn(2, 2) + 1j * np.random.randn(2, 2),
|
||||
y=np.random.randn(3, 3)))
|
||||
|
||||
suite.addTest(
|
||||
ComplexKronTestCase(
|
||||
x=np.random.randn(2, 2) + 1j * np.random.randn(2, 2),
|
||||
y=np.random.randn(2, 2, 3)))
|
||||
return suite
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@ -0,0 +1,101 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import numpy as np
|
||||
from op_test import OpTest
|
||||
|
||||
import paddle
|
||||
import paddle.fluid as fluid
|
||||
import paddle.fluid.dygraph as dg
|
||||
|
||||
|
||||
class TestKronOp(OpTest):
|
||||
def setUp(self):
|
||||
self.op_type = "kron"
|
||||
self.dtype = self._init_dtype()
|
||||
x = np.random.uniform(size=(10, 10)).astype(self.dtype)
|
||||
y = np.random.uniform(size=(10, 10)).astype(self.dtype)
|
||||
out_ref = np.kron(x, y)
|
||||
self.inputs = {'X': x, 'Y': y}
|
||||
self.outputs = {'Out': out_ref}
|
||||
|
||||
def _init_dtype(self):
|
||||
return "float64"
|
||||
|
||||
def test_check_output(self):
|
||||
self.check_output()
|
||||
|
||||
def test_check_grad(self):
|
||||
self.check_grad(['X', 'Y'], 'Out')
|
||||
|
||||
|
||||
class TestKronOp2(TestKronOp):
|
||||
def setUp(self):
|
||||
self.op_type = "kron"
|
||||
self.dtype = self._init_dtype()
|
||||
x = np.random.uniform(size=(5, 5, 4)).astype(self.dtype)
|
||||
y = np.random.uniform(size=(10, 10)).astype(self.dtype)
|
||||
out_ref = np.kron(x, y)
|
||||
self.inputs = {'X': x, 'Y': y}
|
||||
self.outputs = {'Out': out_ref}
|
||||
|
||||
|
||||
class TestKronOp3(TestKronOp):
|
||||
def setUp(self):
|
||||
self.op_type = "kron"
|
||||
self.dtype = self._init_dtype()
|
||||
x = np.random.uniform(size=(10, 10)).astype(self.dtype)
|
||||
y = np.random.uniform(size=(5, 5, 4)).astype(self.dtype)
|
||||
out_ref = np.kron(x, y)
|
||||
self.inputs = {'X': x, 'Y': y}
|
||||
self.outputs = {'Out': out_ref}
|
||||
|
||||
|
||||
class TestKronLayer(unittest.TestCase):
|
||||
def test_case(self):
|
||||
a = np.random.randn(10, 10).astype(np.float64)
|
||||
b = np.random.randn(10, 10).astype(np.float64)
|
||||
|
||||
place = fluid.CPUPlace()
|
||||
with dg.guard(place):
|
||||
a_var = dg.to_variable(a)
|
||||
b_var = dg.to_variable(b)
|
||||
c_var = paddle.kron(a_var, b_var)
|
||||
np.testing.assert_allclose(c_var.numpy(), np.kron(a, b))
|
||||
|
||||
def test_case_with_output(self):
|
||||
a = np.random.randn(10, 10).astype(np.float64)
|
||||
b = np.random.randn(10, 10).astype(np.float64)
|
||||
|
||||
main = fluid.Program()
|
||||
start = fluid.Program()
|
||||
with fluid.unique_name.guard():
|
||||
with fluid.program_guard(main, start):
|
||||
a_var = fluid.data("a", [-1, -1], dtype="float64")
|
||||
b_var = fluid.data("b", [-1, -1], dtype="float64")
|
||||
out_var = fluid.layers.create_tensor("float64", "c")
|
||||
paddle.kron(a_var, b_var, out=out_var)
|
||||
|
||||
place = fluid.CPUPlace()
|
||||
exe = fluid.Executor(place)
|
||||
exe.run(start)
|
||||
c, = exe.run(main, feed={'a': a, 'b': b}, fetch_list=[out_var])
|
||||
np.testing.assert_allclose(c, np.kron(a, b))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Loading…
Reference in new issue