|
|
|
@ -14,8 +14,8 @@
|
|
|
|
|
* limitations under the License.
|
|
|
|
|
*/
|
|
|
|
|
|
|
|
|
|
#ifndef MINDSPORE_CHOLESKY_SOLVE_GPU_KERNEL_H
|
|
|
|
|
#define MINDSPORE_CHOLESKY_SOLVE_GPU_KERNEL_H
|
|
|
|
|
#ifndef MINDSPORE_CHOLESKY_TRSM_SOLVE_GPU_KERNEL_H
|
|
|
|
|
#define MINDSPORE_CHOLESKY_TRSM_SOLVE_GPU_KERNEL_H
|
|
|
|
|
#include <cublas_v2.h>
|
|
|
|
|
#include <cuda_runtime_api.h>
|
|
|
|
|
#include <vector>
|
|
|
|
@ -29,10 +29,10 @@
|
|
|
|
|
namespace mindspore {
|
|
|
|
|
namespace kernel {
|
|
|
|
|
template <typename T>
|
|
|
|
|
class CholeskyGpuKernel : public GpuKernel {
|
|
|
|
|
class CholeskyTrsmGpuKernel : public GpuKernel {
|
|
|
|
|
public:
|
|
|
|
|
CholeskyGpuKernel() : batch_(0), m_(0), lda_(0), is_null_input_(false), handle_(nullptr) {}
|
|
|
|
|
~CholeskyGpuKernel() = default;
|
|
|
|
|
CholeskyTrsmGpuKernel() : batch_(0), m_(0), lda_(0), is_null_input_(false), handle_(nullptr) {}
|
|
|
|
|
~CholeskyTrsmGpuKernel() = default;
|
|
|
|
|
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
|
|
|
|
|
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
|
|
|
|
|
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
|
|
|
|
@ -111,12 +111,12 @@ class CholeskyGpuKernel : public GpuKernel {
|
|
|
|
|
if (in_shape.size() == 2) {
|
|
|
|
|
batch_ = 1;
|
|
|
|
|
if (in_shape[0] != in_shape[1]) {
|
|
|
|
|
MS_LOG(ERROR) << "Cholesky need square matrix as input.";
|
|
|
|
|
MS_LOG(ERROR) << "CholeskyTrsm need square matrix as input.";
|
|
|
|
|
}
|
|
|
|
|
} else if (in_shape.size() == 3) {
|
|
|
|
|
batch_ = SizeToInt(in_shape[0]);
|
|
|
|
|
if (in_shape[1] != in_shape[2]) {
|
|
|
|
|
MS_LOG(ERROR) << "Cholesky need square matrix as input.";
|
|
|
|
|
MS_LOG(ERROR) << "CholeskyTrsm need square matrix as input.";
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(ERROR) << "Input Only support Rank 2 OR 3";
|
|
|
|
@ -140,12 +140,12 @@ class CholeskyGpuKernel : public GpuKernel {
|
|
|
|
|
InitSizeLists();
|
|
|
|
|
} else {
|
|
|
|
|
if (in_shape.size() != 2) {
|
|
|
|
|
MS_LOG(ERROR) << "Cholesky Split Matrix Need Input Rank as 2.";
|
|
|
|
|
MS_LOG(ERROR) << "CholeskyTrsm Split Matrix Need Input Rank as 2.";
|
|
|
|
|
}
|
|
|
|
|
height = in_shape[0];
|
|
|
|
|
width = in_shape[1];
|
|
|
|
|
if (height != width) {
|
|
|
|
|
MS_LOG(ERROR) << "Cholesky Split Matrix Need Square Matrix as Input.";
|
|
|
|
|
MS_LOG(ERROR) << "CholeskyTrsm Split Matrix Need Square Matrix as Input.";
|
|
|
|
|
}
|
|
|
|
|
if (SizeToInt(height) <= split_dim) {
|
|
|
|
|
use_split_matrix = false;
|