Add cholesky_op (#23543)
* Add cholesky_op forward part. test=develop
* Complete cholesky_op forward part. test=develop
* Add cholesky_op backward part. test=develop
* Complete cholesky_op backward part. test=develop
* Refine cholesky_op error check and docs. test=develop
* Add grad_check unit test for cholesky_op. test=develop
* Fix sample code in cholesky doc. test=develop
* Refine some error messages of cholesky_op. test=develop
* Refine some error messages of cholesky_op. test=develop
* Remove unused input in cholesky_grad. test=develop
* Remove unused input in cholesky_grad. test=develop
* Fix stream for cusolverDnSetStream. test=develop
* Update PADDLE_ENFORCE_CUDA_SUCCESS from cholesky_op to adapt to latest code. test=develop
* Add CUSOLVER ERROR in enforce.h test=develop
* Fix the missing return value in cholesky. test=develop
parent 461e6a01ec
commit a8c0fb4e86
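For orientation, a minimal usage sketch of the new Python API, in the static-graph style used by the unit test at the end of this diff. The variable names and the identity shift used to keep the input positive-definite are illustrative assumptions, not part of the commit:

    import numpy as np
    import paddle
    import paddle.fluid as fluid

    # Build a symmetric positive-definite (SPD) input matrix.
    root = np.random.random((3, 3)).astype("float64")
    x_np = root @ root.T + 1e-5 * np.eye(3)   # assumed SPD construction

    x = fluid.data(name="x", shape=[3, 3], dtype="float64")
    out = paddle.cholesky(x, upper=False)     # lower factor L with L @ L.T == x

    exe = fluid.Executor(fluid.CPUPlace())
    l, = exe.run(feed={"x": x_np}, fetch_list=[out])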
@@ -0,0 +1,121 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/cholesky_op.h"

namespace paddle {
namespace operators {

using framework::OpKernelType;
using framework::Tensor;

class CholeskyOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Cholesky");
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Cholesky");
    auto dims = ctx->GetInputDim("X");
    auto rank = dims.size();
    PADDLE_ENFORCE_GE(rank, 2,
                      platform::errors::InvalidArgument(
                          "The Input(X) should have at least 2 dimensions. But "
                          "received a %d dimension tensor.",
                          rank));
    PADDLE_ENFORCE_EQ(
        dims[rank - 2], dims[rank - 1],
        platform::errors::InvalidArgument(
            "The inner-most 2 dimensions of Input(X) all should be symmetric "
            "positive-definite matrices and have the same size. But received "
            "X's shape[-2] = %d and shape[-1] = %d.",
            dims[rank - 2], dims[rank - 1]));
    ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
  }
};

class CholeskyOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "(Tensor), The input tensor of cholesky op. Its shape should be "
             "[*, M, M] where * is zero or more batch dimensions, and matrices "
             "on the inner-most 2 dimensions all should be symmetric "
             "positive-definite.");
    AddOutput("Out",
              "(Tensor), The output tensor of cholesky op. It has the same "
              "shape as the input, and it is composed of upper-triangular or "
              "lower-triangular Cholesky factors of each of the individual "
              "matrices.");
    AddAttr<bool>("upper",
                  "(bool, default false), flag indicating whether to return "
                  "upper or lower triangular matrices. Default: False")
        .SetDefault(false);
    AddComment(R"DOC(
Cholesky Operator.

Computes the Cholesky decomposition of one symmetric positive-definite matrix
or batches of symmetric positive-definite matrices.

)DOC");
  }
};

class CholeskyGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out", "CholeskyGrad");
    OP_INOUT_CHECK(ctx->HasInputs(framework::GradVarName("Out")), "Input",
                   "Out@GRAD", "CholeskyGrad");
    auto dims = ctx->GetInputDim("Out");
    auto x_grad_name = framework::GradVarName("X");
    if (ctx->HasOutput(x_grad_name)) {
      ctx->SetOutputDim(x_grad_name, dims);
    }
  }
};

template <typename T>
class CholeskyGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType(this->ForwardOpType() + "_grad");
    op->SetInput("Out", this->Output("Out"));
    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
    op->SetAttrMap(this->Attrs());
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(cholesky, ops::CholeskyOp, ops::CholeskyOpMaker,
                  ops::CholeskyGradOpMaker<paddle::framework::OpDesc>,
                  ops::CholeskyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(cholesky_grad, ops::CholeskyGradOp);

REGISTER_OP_CPU_KERNEL(cholesky, ops::CholeskyCPUKernel<float>,
                       ops::CholeskyCPUKernel<double>);

REGISTER_OP_CPU_KERNEL(
    cholesky_grad,
    ops::CholeskyGradKernel<paddle::platform::CPUDeviceContext, float>,
    ops::CholeskyGradKernel<paddle::platform::CPUDeviceContext, double>);
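Note that cholesky_grad is wired to take only Out and Out@GRAD, which is consistent with the standard reverse-mode identity for the lower Cholesky factor: the gradient with respect to X can be written purely in terms of the factor L and its incoming gradient, so the forward input X is not needed (hence the "Remove unused input in cholesky_grad" commits above). The grad kernel itself (CholeskyGradKernel, presumably defined in the suppressed cholesky_op.h diff below) is not shown here; the following NumPy sketch is only the textbook identity, not necessarily the implementation in this commit:

    import numpy as np

    def cholesky_vjp(l, l_bar):
        # Textbook identity for X = L @ L.T (lower-factor convention):
        #   X_bar = 0.5 * (S + S.T),  S = L^{-T} @ Phi(L^T @ L_bar) @ L^{-1},
        # where Phi keeps the lower triangle and halves its diagonal.
        phi = np.tril(l.T @ l_bar)
        phi[np.diag_indices_from(phi)] *= 0.5
        s = np.linalg.solve(l.T, phi)      # L^{-T} @ Phi
        s = np.linalg.solve(l.T, s.T).T    # (L^{-T} @ Phi) @ L^{-1}
        return 0.5 * (s + s.T)             # symmetrize

For upper=True the same identity applies after transposing, since U = L.T.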
@@ -0,0 +1,153 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <thrust/device_vector.h>
#include <algorithm>
#include <vector>
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/operators/cholesky_op.h"
#include "paddle/fluid/platform/dynload/cusolver.h"

namespace paddle {
namespace operators {

template <typename T>
class CholeskyGPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto& dev_ctx =
        context.template device_context<platform::CUDADeviceContext>();

    const Tensor* x = context.Input<Tensor>("X");
    Tensor* out = context.Output<Tensor>("Out");

    bool upper = context.Attr<bool>("upper");
    auto& dims = x->dims();
    int batch_count = 1;
    for (int i = 0; i < dims.size() - 2; i++) {
      batch_count *= dims[i];
    }
    int m = dims[dims.size() - 1];
    int tensor_size = batch_count * m * m;

    const auto* x_data = x->data<T>();
    auto* out_data = out->mutable_data<T>(context.GetPlace());

    // matrices are assumed to be stored in column-major order in cusolver
    cublasFillMode_t uplo =
        upper ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER;
    // potrf is in-place, thus copy the triangular part of the input matrices
    // to the output and set the other triangular part to 0 first
    platform::ForRange<platform::CUDADeviceContext> for_range(dev_ctx,
                                                              tensor_size);
    if (upper) {
      MatrixBandPartFunctor<T> matrix_band_part_functor(
          m, m, /* num_lower_diags */ 0, /* num_upper_diags */ m, x_data,
          out_data);
      for_range(matrix_band_part_functor);
    } else {
      MatrixBandPartFunctor<T> matrix_band_part_functor(
          m, m, /* num_lower_diags */ m, /* num_upper_diags */ 0, x_data,
          out_data);
      for_range(matrix_band_part_functor);
    }

    // TODO(guosheng): Add callback to check info
    auto info = memory::Alloc(dev_ctx, sizeof(int) * batch_count);
    auto* info_ptr = reinterpret_cast<int*>(info->ptr());

#if CUDA_VERSION >= 9020
    if (batch_count > 1) {
      std::vector<T*> output_ptrs;
      for (int i = 0; i < batch_count; i++) {
        output_ptrs.emplace_back(out_data + i * m * m);
      }
      thrust::device_vector<T*> dev_output_ptrs(output_ptrs.begin(),
                                                output_ptrs.end());
      PotrfBatched(dev_ctx, uplo, m,
                   thrust::raw_pointer_cast(dev_output_ptrs.data()), m,
                   info_ptr, batch_count);
      // TODO(guosheng): There seems to be a bug in cusolver potrfBatched and
      // we need to clear the upper triangle of the output. Remove this
      // workaround once the bug is fixed.
      if (!upper) {
        MatrixBandPartFunctor<T> matrix_band_part_functor(
            m, m, /* num_lower_diags */ m, /* num_upper_diags */ 0, out_data,
            out_data);
        for_range(matrix_band_part_functor);
      }
    } else {
#endif
      for (int i = 0; i < batch_count; i++) {
        Potrf(dev_ctx, uplo, m, out_data + i * m * m, m, info_ptr + i);
      }

#if CUDA_VERSION >= 9020
    }
#endif
  }

  void Potrf(const platform::CUDADeviceContext& dev_ctx, cublasFillMode_t uplo,
             int n, T* A, int lda, int* info) const;

  void PotrfBatched(const platform::CUDADeviceContext& dev_ctx,
                    cublasFillMode_t uplo, int n, T* Aarray[], int lda,
                    int* info_array, int batch_size) const;
};

#define FUNC_WITH_TYPES(m) m(float, S) m(double, D)

#define POTRF_INSTANCE(T, C)                                                   \
  template <>                                                                  \
  void CholeskyGPUKernel<T>::Potrf(const platform::CUDADeviceContext& dev_ctx, \
                                   cublasFillMode_t uplo, int n, T* A,         \
                                   int lda, int* info) const {                 \
    auto handle = dev_ctx.cusolver_dn_handle();                                \
    int workspace_size = 0;                                                    \
    PADDLE_ENFORCE_CUDA_SUCCESS(                                               \
        platform::dynload::cusolverDn##C##potrf_bufferSize(                    \
            handle, uplo, n, A, lda, &workspace_size));                        \
    auto workspace = memory::Alloc(dev_ctx, workspace_size);                   \
    T* workspace_ptr = reinterpret_cast<T*>(workspace->ptr());                 \
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cusolverDn##C##potrf(       \
        handle, uplo, n, A, lda, workspace_ptr, workspace_size, info));        \
  }

FUNC_WITH_TYPES(POTRF_INSTANCE);

#if CUDA_VERSION >= 9020
#define POTRF_BATCH_INSTANCE(T, C)                                          \
  template <>                                                               \
  void CholeskyGPUKernel<T>::PotrfBatched(                                  \
      const platform::CUDADeviceContext& dev_ctx, cublasFillMode_t uplo,    \
      int n, T* Aarray[], int lda, int* info_array, int batch_size) const { \
    auto handle = dev_ctx.cusolver_dn_handle();                             \
    PADDLE_ENFORCE_CUDA_SUCCESS(                                            \
        platform::dynload::cusolverDn##C##potrfBatched(                     \
            handle, uplo, n, Aarray, lda, info_array, batch_size));         \
  }

FUNC_WITH_TYPES(POTRF_BATCH_INSTANCE);
#endif

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(cholesky, ops::CholeskyGPUKernel<float>,
                        ops::CholeskyGPUKernel<double>);
REGISTER_OP_CUDA_KERNEL(
    cholesky_grad,
    ops::CholeskyGradKernel<paddle::platform::CUDADeviceContext, float>,
    ops::CholeskyGradKernel<paddle::platform::CUDADeviceContext, double>);
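The uplo flip in the GPU kernel above (upper ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER) comes from the storage-order mismatch noted in the code comment: Paddle tensors are row-major while cusolver assumes column-major, and because the input is symmetric both views describe the same matrix. A small NumPy illustration of the equivalence; it is purely illustrative and not part of the commit:

    import numpy as np

    rng = np.random.RandomState(0)
    root = rng.rand(4, 4)
    a = root @ root.T + 4.0 * np.eye(4)   # symmetric positive-definite

    l = np.linalg.cholesky(a)             # lower factor in the row-major view
    u = l.T                               # upper factor: u.T @ u == a

    # cusolver reads the row-major buffer of `a` as a column-major matrix,
    # i.e. as a.T, which equals `a` by symmetry. Asking it for the *upper*
    # factor therefore writes memory that, read back row-major, is exactly
    # the *lower* factor.
    buffer_as_written = u.ravel(order="F")            # column-major memory of u
    assert np.allclose(buffer_as_written.reshape(4, 4), l)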
File diff suppressed because it is too large
@@ -0,0 +1,30 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/platform/dynload/cusolver.h"

namespace paddle {
namespace platform {
namespace dynload {

std::once_flag cusolver_dso_flag;
void *cusolver_dso_handle;

#define DEFINE_WRAP(__name) DynLoad__##__name __name

CUSOLVER_ROUTINE_EACH(DEFINE_WRAP);

}  // namespace dynload
}  // namespace platform
}  // namespace paddle
@@ -0,0 +1,75 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once

#include <cusolverDn.h>

#include <mutex>  // NOLINT
#include "paddle/fluid/platform/port.h"

#include "paddle/fluid/platform/dynload/dynamic_loader.h"

namespace paddle {
namespace platform {
namespace dynload {
extern std::once_flag cusolver_dso_flag;
extern void *cusolver_dso_handle;
#ifdef PADDLE_USE_DSO
#define DECLARE_DYNAMIC_LOAD_CUSOLVER_WRAP(__name)                   \
  struct DynLoad__##__name {                                         \
    template <typename... Args>                                      \
    cusolverStatus_t operator()(Args... args) {                      \
      using cusolverFunc = decltype(&::__name);                      \
      std::call_once(cusolver_dso_flag, []() {                       \
        cusolver_dso_handle =                                        \
            paddle::platform::dynload::GetCusolverDsoHandle();       \
      });                                                            \
      static void *p_##__name = dlsym(cusolver_dso_handle, #__name); \
      return reinterpret_cast<cusolverFunc>(p_##__name)(args...);    \
    }                                                                \
  };                                                                 \
  extern DynLoad__##__name __name
#else
#define DECLARE_DYNAMIC_LOAD_CUSOLVER_WRAP(__name) \
  struct DynLoad__##__name {                       \
    template <typename... Args>                    \
    cusolverStatus_t operator()(Args... args) {    \
      return ::__name(args...);                    \
    }                                              \
  };                                               \
  extern DynLoad__##__name __name
#endif

#define CUSOLVER_ROUTINE_EACH(__macro)  \
  __macro(cusolverDnCreate);            \
  __macro(cusolverDnDestroy);           \
  __macro(cusolverDnSetStream);         \
  __macro(cusolverDnSpotrf_bufferSize); \
  __macro(cusolverDnDpotrf_bufferSize); \
  __macro(cusolverDnSpotrf);            \
  __macro(cusolverDnDpotrf);

CUSOLVER_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUSOLVER_WRAP);

#if CUDA_VERSION >= 9020
#define CUSOLVER_ROUTINE_EACH_R1(__macro) \
  __macro(cusolverDnSpotrfBatched);       \
  __macro(cusolverDnDpotrfBatched);

CUSOLVER_ROUTINE_EACH_R1(DECLARE_DYNAMIC_LOAD_CUSOLVER_WRAP)
#endif

}  // namespace dynload
}  // namespace platform
}  // namespace paddle
@@ -0,0 +1,94 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as layers
import paddle.fluid.core as core
from op_test import OpTest, skip_check_grad_ci
from gradient_checker import grad_check
from decorator_helper import prog_scope


@skip_check_grad_ci(
    reason="The input of cholesky_op should always be symmetric positive-definite. "
    "However, OpTest calculates the numeric gradient of each element in the input "
    "via a small finite difference, which makes the input no longer symmetric "
    "positive-definite, so the Cholesky decomposition cannot be computed. "
    "Instead we use gradient_checker.grad_check to perform the gradient "
    "check of cholesky_op, since it supports checking gradients with a program "
    "and we can construct symmetric positive-definite matrices in the program")
class TestCholeskyOp(OpTest):
    def setUp(self):
        self.op_type = "cholesky"
        self._input_shape = (2, 32, 32)
        self._upper = True
        self.init_config()
        self.trans_dims = list(range(len(self._input_shape) - 2)) + [
            len(self._input_shape) - 1, len(self._input_shape) - 2
        ]
        self.root_data = np.random.random(self._input_shape).astype("float64")
        # construct symmetric positive-definite matrices
        input_data = np.matmul(
            self.root_data, self.root_data.transpose(self.trans_dims)) + 1e-05
        output_data = np.linalg.cholesky(input_data).astype("float64")
        if self._upper:
            output_data = output_data.transpose(self.trans_dims)
        self.inputs = {"X": input_data}
        self.attrs = {"upper": self._upper}
        self.outputs = {"Out": output_data}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        places = [fluid.CPUPlace()]
        if core.is_compiled_with_cuda():
            places.append(fluid.CUDAPlace(0))
        for p in places:
            self.func(p)

    @prog_scope()
    def func(self, place):
        # use small sizes since computing Jacobian gradients is time consuming
        root_data = self.root_data[..., :3, :3]
        prog = fluid.Program()
        with fluid.program_guard(prog):
            root = layers.create_parameter(
                dtype=root_data.dtype, shape=root_data.shape)
            root_t = layers.transpose(root, self.trans_dims)
            x = layers.matmul(x=root, y=root_t) + 1e-05
            out = paddle.cholesky(x, upper=self.attrs["upper"])
            grad_check(root, out, x_init=root_data, place=place)

    def init_config(self):
        self._upper = True


class TestCholeskyOpLower(TestCholeskyOp):
    def init_config(self):
        self._upper = False


class TestCholeskyOp2D(TestCholeskyOp):
    def init_config(self):
        self._input_shape = (64, 64)


if __name__ == "__main__":
    unittest.main()
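The skip_check_grad_ci reason above can be reproduced directly in NumPy: the per-element perturbation used by OpTest's numeric gradient destroys the symmetry that the Cholesky factorization requires, which is why the test instead runs grad_check on a program that rebuilds a symmetric positive-definite input. A throwaway illustration, not part of the test file:

    import numpy as np

    root = np.random.random((3, 3))
    x = root @ root.T + np.eye(3)   # symmetric positive-definite
    x_perturbed = x.copy()
    x_perturbed[0, 1] += 1e-6       # a single finite-difference step on one element
    # Symmetry is lost, so the op's SPD precondition no longer holds:
    print(np.allclose(x_perturbed, x_perturbed.T))  # False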