Add the implementation of inverse (#23310)
parent
34122e665e
commit
ecfddebbef
@ -0,0 +1,129 @@
|
||||
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/operators/inverse_op.h"
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
class InverseOp : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
|
||||
void InferShape(framework::InferShapeContext* ctx) const override {
|
||||
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "Inverse");
|
||||
OP_INOUT_CHECK(ctx->HasOutput("Output"), "Output", "Output", "Inverse");
|
||||
|
||||
auto input_dims = ctx->GetInputDim("Input");
|
||||
int64_t input_rank = input_dims.size();
|
||||
PADDLE_ENFORCE_GE(
|
||||
input_rank, 2,
|
||||
platform::errors::InvalidArgument(
|
||||
"The dimension of Input(Input) is expected to be no less than 2. "
|
||||
"But recieved: Input(Input)'s dimension = %d, shape = [%s].",
|
||||
input_rank, input_dims));
|
||||
if (input_dims[input_rank - 2] > 0 && input_dims[input_rank - 1] > 0) {
|
||||
PADDLE_ENFORCE_EQ(input_dims[input_rank - 2], input_dims[input_rank - 1],
|
||||
platform::errors::InvalidArgument(
|
||||
"The last two dimensions are expected to be equal. "
|
||||
"But recieved: %d and %d; "
|
||||
"Input(Input)'s shape = [%s].",
|
||||
input_dims[input_rank - 2],
|
||||
input_dims[input_rank - 1], input_dims));
|
||||
}
|
||||
|
||||
ctx->SetOutputDim("Output", input_dims);
|
||||
ctx->ShareLoD("Input", /*->*/ "Output");
|
||||
}
|
||||
};
|
||||
|
||||
class InverseOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput {
|
||||
protected:
|
||||
std::unordered_map<std::string, std::string>& GetInputOutputWithSameType()
|
||||
const override {
|
||||
static std::unordered_map<std::string, std::string> m{
|
||||
{"Input", /*->*/ "Output"}};
|
||||
return m;
|
||||
}
|
||||
};
|
||||
|
||||
class InverseGradOp : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
|
||||
void InferShape(framework::InferShapeContext* ctx) const override {
|
||||
auto input_grad = framework::GradVarName("Input");
|
||||
auto output_grad = framework::GradVarName("Output");
|
||||
|
||||
OP_INOUT_CHECK(ctx->HasInput("Output"), "Input", "Output", "InverseGrad");
|
||||
OP_INOUT_CHECK(ctx->HasInput(output_grad), "Input", output_grad,
|
||||
"InverseGrad");
|
||||
|
||||
if (ctx->HasOutput(input_grad)) {
|
||||
ctx->SetOutputDim(input_grad, ctx->GetInputDim(output_grad));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
class InverseOpMaker : public framework::OpProtoAndCheckerMaker {
|
||||
public:
|
||||
void Make() override {
|
||||
AddInput(
|
||||
"Input",
|
||||
"(Tensor) A square matrix (2-D Tensor) or batches of square matrices"
|
||||
" to inverse.");
|
||||
AddOutput("Output", "(Tensor) The inverse of input matrix.");
|
||||
AddComment(R"DOC(
|
||||
Inverse Operator
|
||||
|
||||
Takes the inverse of the square matrix.
|
||||
)DOC");
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class InverseGradOpMaker : public framework::SingleGradOpMaker<T> {
|
||||
public:
|
||||
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
|
||||
|
||||
protected:
|
||||
void Apply(GradOpPtr<T> grad) const override {
|
||||
grad->SetType(this->ForwardOpType() + "_grad");
|
||||
grad->SetInput("Output", this->Output("Output"));
|
||||
grad->SetInput(framework::GradVarName("Output"),
|
||||
this->OutputGrad("Output"));
|
||||
grad->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input"));
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
REGISTER_OPERATOR(inverse, ops::InverseOp, ops::InverseOpMaker,
|
||||
ops::InverseOpInferVarType,
|
||||
ops::InverseGradOpMaker<paddle::framework::OpDesc>,
|
||||
ops::InverseGradOpMaker<paddle::imperative::OpBase>);
|
||||
|
||||
REGISTER_OPERATOR(inverse_grad, ops::InverseGradOp);
|
||||
|
||||
REGISTER_OP_CPU_KERNEL(
|
||||
inverse, ops::InverseKernel<paddle::platform::CPUDeviceContext, float>,
|
||||
ops::InverseKernel<paddle::platform::CPUDeviceContext, double>);
|
||||
REGISTER_OP_CPU_KERNEL(
|
||||
inverse_grad,
|
||||
ops::InverseGradKernel<paddle::platform::CPUDeviceContext, float>,
|
||||
ops::InverseGradKernel<paddle::platform::CPUDeviceContext, double>);
|
@ -0,0 +1,25 @@
|
||||
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/operators/inverse_op.h"
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
|
||||
REGISTER_OP_CUDA_KERNEL(
|
||||
inverse, ops::InverseKernel<paddle::platform::CUDADeviceContext, float>,
|
||||
ops::InverseKernel<paddle::platform::CUDADeviceContext, double>);
|
||||
REGISTER_OP_CUDA_KERNEL(
|
||||
inverse_grad,
|
||||
ops::InverseGradKernel<paddle::platform::CUDADeviceContext, float>,
|
||||
ops::InverseGradKernel<paddle::platform::CUDADeviceContext, double>);
|
@ -0,0 +1,70 @@
|
||||
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
#include "paddle/fluid/operators/math/blas.h"
|
||||
#include "paddle/fluid/operators/math/matrix_inverse.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
template <typename DeviceContext, typename T>
|
||||
class InverseKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext& context) const override {
|
||||
auto* input = context.Input<framework::Tensor>("Input");
|
||||
auto* output = context.Output<framework::Tensor>("Output");
|
||||
output->mutable_data<T>(context.GetPlace());
|
||||
|
||||
auto& dev_ctx = context.template device_context<DeviceContext>();
|
||||
math::MatrixInverseFunctor<DeviceContext, T> mat_inv;
|
||||
mat_inv(dev_ctx, *input, output);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename DeviceContext, typename T>
|
||||
class InverseGradKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext& context) const override {
|
||||
auto* a_inv = context.Input<framework::Tensor>("Output");
|
||||
auto* a_inv_grad =
|
||||
context.Input<framework::Tensor>(framework::GradVarName("Output"));
|
||||
auto* a_grad =
|
||||
context.Output<framework::Tensor>(framework::GradVarName("Input"));
|
||||
|
||||
if (a_grad) {
|
||||
a_grad->mutable_data<T>(context.GetPlace());
|
||||
|
||||
auto blas = math::GetBlas<DeviceContext, T>(context);
|
||||
auto& dev_ctx = context.template device_context<DeviceContext>();
|
||||
framework::Tensor tmp_out =
|
||||
context.AllocateTmpTensor<T, DeviceContext>(a_inv->dims(), dev_ctx);
|
||||
|
||||
auto mat_dim_a0 =
|
||||
math::CreateMatrixDescriptor(a_inv_grad->dims(), 0, false);
|
||||
auto mat_dim_b0 = math::CreateMatrixDescriptor(a_inv->dims(), 0, true);
|
||||
blas.MatMul(*a_inv_grad, mat_dim_a0, *a_inv, mat_dim_b0, T(1), &tmp_out,
|
||||
T(0));
|
||||
|
||||
auto mat_dim_a1 = math::CreateMatrixDescriptor(a_inv->dims(), 0, true);
|
||||
auto mat_dim_b1 = math::CreateMatrixDescriptor(tmp_out.dims(), 0, false);
|
||||
blas.MatMul(*a_inv, mat_dim_a1, tmp_out, mat_dim_b1, T(-1), a_grad, T(0));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,62 @@
|
||||
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/operators/math/matrix_inverse.h"
|
||||
#include "Eigen/Core"
|
||||
#include "Eigen/LU"
|
||||
#include "paddle/fluid/operators/math/blas.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace math {
|
||||
|
||||
template <typename T>
|
||||
class MatrixInverseFunctor<platform::CPUDeviceContext, T> {
|
||||
using Matrix =
|
||||
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
|
||||
using EigenMatrixMap = Eigen::Map<Matrix>;
|
||||
using ConstEigenMatrixMap = Eigen::Map<const Matrix>;
|
||||
|
||||
public:
|
||||
void operator()(const platform::CPUDeviceContext& context,
|
||||
const framework::Tensor& a, framework::Tensor* a_inv) {
|
||||
const auto& mat_dims = a.dims();
|
||||
const int rank = mat_dims.size();
|
||||
int n = mat_dims[rank - 1];
|
||||
int batch_size = rank > 2 ? a.numel() / (n * n) : 1;
|
||||
|
||||
const T* a_ptr = a.data<T>();
|
||||
T* a_inv_ptr = a_inv->mutable_data<T>(context.GetPlace());
|
||||
|
||||
for (int i = 0; i < batch_size; ++i) {
|
||||
ConstEigenMatrixMap mat(a_ptr + i * n * n, n, n);
|
||||
EigenMatrixMap mat_inv(a_inv_ptr + i * n * n, n, n);
|
||||
Eigen::PartialPivLU<Matrix> lu;
|
||||
lu.compute(mat);
|
||||
|
||||
const T min_abs_pivot = lu.matrixLU().diagonal().cwiseAbs().minCoeff();
|
||||
PADDLE_ENFORCE_GT(
|
||||
min_abs_pivot, static_cast<T>(0),
|
||||
platform::errors::InvalidArgument("Input is not invertible."));
|
||||
mat_inv.noalias() = lu.inverse();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template class MatrixInverseFunctor<platform::CPUDeviceContext, float>;
|
||||
template class MatrixInverseFunctor<platform::CPUDeviceContext, double>;
|
||||
|
||||
} // namespace math
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,102 @@
|
||||
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/operators/math/matrix_inverse.h"
|
||||
#include "paddle/fluid/memory/malloc.h"
|
||||
#include "paddle/fluid/operators/math/blas.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace math {
|
||||
|
||||
template <typename T>
|
||||
class MatrixInverseFunctor<platform::CUDADeviceContext, T> {
|
||||
public:
|
||||
void operator()(const platform::CUDADeviceContext& context,
|
||||
const framework::Tensor& a, framework::Tensor* a_inv) {
|
||||
const auto& mat_dims = a.dims();
|
||||
const int rank = mat_dims.size();
|
||||
int n = mat_dims[rank - 1];
|
||||
int batch_size = rank > 2 ? a.numel() / (n * n) : 1;
|
||||
|
||||
memory::allocation::AllocationPtr tmp_gpu_mat_data;
|
||||
const T* gpu_mat = a.data<T>();
|
||||
if (n >= 32) {
|
||||
// Copy all elements of input matrix A to a temporary memory space to
|
||||
// avoid being overriden by getrf.
|
||||
tmp_gpu_mat_data = memory::Alloc(context, a.numel() * sizeof(T));
|
||||
memory::Copy(boost::get<platform::CUDAPlace>(context.GetPlace()),
|
||||
tmp_gpu_mat_data->ptr(),
|
||||
boost::get<platform::CUDAPlace>(context.GetPlace()),
|
||||
a.data<void>(), a.numel() * sizeof(T), context.stream());
|
||||
gpu_mat = reinterpret_cast<const T*>(tmp_gpu_mat_data->ptr());
|
||||
}
|
||||
|
||||
std::vector<const T*> cpu_ptrs(batch_size * 2);
|
||||
for (int i = 0; i < batch_size; ++i) {
|
||||
cpu_ptrs[i] = gpu_mat + i * n * n;
|
||||
cpu_ptrs[i + batch_size] = a_inv->data<T>() + i * n * n;
|
||||
}
|
||||
|
||||
// Copy the addresses of A and A_inv from host to device.
|
||||
memory::allocation::AllocationPtr tmp_gpu_ptrs_data =
|
||||
memory::Alloc(context, cpu_ptrs.size() * sizeof(T*));
|
||||
memory::Copy(boost::get<platform::CUDAPlace>(context.GetPlace()),
|
||||
tmp_gpu_ptrs_data->ptr(), platform::CPUPlace(),
|
||||
static_cast<void*>(cpu_ptrs.data()),
|
||||
cpu_ptrs.size() * sizeof(T*), context.stream());
|
||||
T** gpu_inv_ptrs =
|
||||
reinterpret_cast<T**>(tmp_gpu_ptrs_data->ptr()) + batch_size;
|
||||
|
||||
// Allocate device memory for info and pivots.
|
||||
int num_ints = n < 32 ? batch_size : batch_size * (n + 1);
|
||||
memory::allocation::AllocationPtr tmp_gpu_info_data =
|
||||
memory::Alloc(context, num_ints * sizeof(int));
|
||||
int* gpu_info_ptr = reinterpret_cast<int*>(tmp_gpu_info_data->ptr());
|
||||
|
||||
auto blas = math::GetBlas<platform::CUDADeviceContext, T>(context);
|
||||
|
||||
// This functions in cuBLAS is intended to be used for matrices of small
|
||||
// sizes where the launch overhead is a significant factor.
|
||||
// TODO(Xreki): call function in cusolver for large matrices.
|
||||
if (n < 32) {
|
||||
// cublas<S/D>matinvBatched is a short cut of cublas<S/D>getrfBatched
|
||||
// plus cublas<S/D>getriBatched.
|
||||
// However it only works if N is less than 32. If not, we need to
|
||||
// go through cublas<S/D>getrfBatched and cublas<S/D>getriBatched.
|
||||
blas.BatchedMatInv(n,
|
||||
reinterpret_cast<const T**>(tmp_gpu_ptrs_data->ptr()),
|
||||
gpu_inv_ptrs, gpu_info_ptr, batch_size);
|
||||
} else {
|
||||
// This function performs the LU factorization of each matrix A by the
|
||||
// equation P * A = L * U. L and U are written back to original matrix A,
|
||||
// and diagonal elements of L are discarded.
|
||||
int* gpu_pivot_ptr =
|
||||
reinterpret_cast<int*>(tmp_gpu_info_data->ptr()) + batch_size;
|
||||
blas.BatchedGETRF(n, reinterpret_cast<T**>(tmp_gpu_ptrs_data->ptr()),
|
||||
gpu_pivot_ptr, gpu_info_ptr, batch_size);
|
||||
|
||||
blas.BatchedGETRI(n,
|
||||
reinterpret_cast<const T**>(tmp_gpu_ptrs_data->ptr()),
|
||||
gpu_pivot_ptr, gpu_inv_ptrs, gpu_info_ptr, batch_size);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template class MatrixInverseFunctor<platform::CUDADeviceContext, float>;
|
||||
template class MatrixInverseFunctor<platform::CUDADeviceContext, double>;
|
||||
|
||||
} // namespace math
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,34 @@
|
||||
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include "paddle/fluid/framework/tensor.h"
|
||||
#include "paddle/fluid/platform/device_context.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace math {
|
||||
|
||||
template <typename DeviceContext, typename T>
|
||||
class MatrixInverseFunctor {
|
||||
public:
|
||||
void operator()(const DeviceContext& context, const framework::Tensor& a,
|
||||
framework::Tensor* a_inv);
|
||||
};
|
||||
|
||||
} // namespace math
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,144 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
import numpy as np
|
||||
import paddle.fluid as fluid
|
||||
import paddle.fluid.core as core
|
||||
import paddle
|
||||
from op_test import OpTest
|
||||
|
||||
|
||||
class TestInverseOp(OpTest):
|
||||
def config(self):
|
||||
self.matrix_shape = [10, 10]
|
||||
self.dtype = "float64"
|
||||
|
||||
def setUp(self):
|
||||
self.op_type = "inverse"
|
||||
self.config()
|
||||
|
||||
np.random.seed(123)
|
||||
mat = np.random.random(self.matrix_shape).astype(self.dtype)
|
||||
inverse = np.linalg.inv(mat)
|
||||
|
||||
self.inputs = {'Input': mat}
|
||||
self.outputs = {'Output': inverse}
|
||||
|
||||
def test_check_output(self):
|
||||
self.check_output()
|
||||
|
||||
def test_grad(self):
|
||||
self.check_grad(['Input'], 'Output')
|
||||
|
||||
|
||||
class TestInverseOpBatched(TestInverseOp):
|
||||
def config(self):
|
||||
self.matrix_shape = [8, 4, 4]
|
||||
self.dtype = "float64"
|
||||
|
||||
|
||||
class TestInverseOpLarge(TestInverseOp):
|
||||
def config(self):
|
||||
self.matrix_shape = [32, 32]
|
||||
self.dtype = "float64"
|
||||
|
||||
def test_grad(self):
|
||||
self.check_grad(['Input'], 'Output', max_relative_error=1e-6)
|
||||
|
||||
|
||||
class TestInverseOpFP32(TestInverseOp):
|
||||
def config(self):
|
||||
self.matrix_shape = [10, 10]
|
||||
self.dtype = "float32"
|
||||
|
||||
def test_grad(self):
|
||||
self.check_grad(['Input'], 'Output', max_relative_error=1e-2)
|
||||
|
||||
|
||||
class TestInverseOpBatchedFP32(TestInverseOpFP32):
|
||||
def config(self):
|
||||
self.matrix_shape = [8, 4, 4]
|
||||
self.dtype = "float32"
|
||||
|
||||
|
||||
class TestInverseOpLargeFP32(TestInverseOpFP32):
|
||||
def config(self):
|
||||
self.matrix_shape = [32, 32]
|
||||
self.dtype = "float32"
|
||||
|
||||
|
||||
class TestInverseAPI(unittest.TestCase):
|
||||
def setUp(self):
|
||||
np.random.seed(123)
|
||||
self.places = [fluid.CPUPlace()]
|
||||
if core.is_compiled_with_cuda():
|
||||
self.places.append(fluid.CUDAPlace(0))
|
||||
|
||||
def check_static_result(self, place, with_out=False):
|
||||
with fluid.program_guard(fluid.Program(), fluid.Program()):
|
||||
input = fluid.data(name="input", shape=[4, 4], dtype="float64")
|
||||
if with_out:
|
||||
out = fluid.data(name="output", shape=[4, 4], dtype="float64")
|
||||
else:
|
||||
out = None
|
||||
result = paddle.inverse(input=input, out=out)
|
||||
|
||||
input_np = np.random.random([4, 4]).astype("float64")
|
||||
result_np = np.linalg.inv(input_np)
|
||||
|
||||
exe = fluid.Executor(place)
|
||||
fetches = exe.run(fluid.default_main_program(),
|
||||
feed={"input": input_np},
|
||||
fetch_list=[result])
|
||||
self.assertTrue(np.allclose(fetches[0], np.linalg.inv(input_np)))
|
||||
|
||||
def test_static(self):
|
||||
for place in self.places:
|
||||
self.check_static_result(place=place)
|
||||
|
||||
def test_dygraph(self):
|
||||
for place in self.places:
|
||||
with fluid.dygraph.guard(place):
|
||||
input_np = np.random.random([4, 4]).astype("float64")
|
||||
input = fluid.dygraph.to_variable(input_np)
|
||||
result = paddle.inverse(input)
|
||||
self.assertTrue(
|
||||
np.allclose(result.numpy(), np.linalg.inv(input_np)))
|
||||
|
||||
|
||||
class TestInverseAPIError(unittest.TestCase):
|
||||
def test_errors(self):
|
||||
input_np = np.random.random([4, 4]).astype("float64")
|
||||
|
||||
# input must be Variable.
|
||||
self.assertRaises(TypeError, paddle.inverse, input_np)
|
||||
|
||||
# The data type of input must be float32 or float64.
|
||||
for dtype in ["bool", "int32", "int64", "float16"]:
|
||||
input = fluid.data(name='input_' + dtype, shape=[4, 4], dtype=dtype)
|
||||
self.assertRaises(TypeError, paddle.inverse, input)
|
||||
|
||||
# When out is set, the data type must be the same as input.
|
||||
input = fluid.data(name='input_1', shape=[4, 4], dtype="float32")
|
||||
out = fluid.data(name='output', shape=[4, 4], dtype="float64")
|
||||
self.assertRaises(TypeError, paddle.inverse, input, out)
|
||||
|
||||
# The number of dimensions of input must be >= 2.
|
||||
input = fluid.data(name='input_2', shape=[4], dtype="float32")
|
||||
self.assertRaises(ValueError, paddle.inverse, input)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Reference in new issue