diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/determinant_triangle_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/determinant_triangle_impl.cu
new file mode 100644
index 0000000000..bfe5741c29
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/determinant_triangle_impl.cu
@@ -0,0 +1,83 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "determinant_triangle_impl.cuh"
+
+template <typename T>
+__global__ void DetTriangleKernel(T *input, T *output, size_t matrix_n_, size_t count) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
+    output[i] = 1;
+    for (size_t pos = 0; pos < matrix_n_ * matrix_n_; pos += matrix_n_ + 1) {
+      output[i] *= input[i * matrix_n_ * matrix_n_ + pos];
+    }
+  }
+  return;
+}
+
+template <typename T>
+void DetTriangle(T *input, T *output, size_t matrix_n_, size_t count, cudaStream_t cuda_stream) {
+  DetTriangleKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, matrix_n_, count);
+  return;
+}
+
+__device__ bool dev_error_res = false;
+
+template <typename T>
+__global__ void CheckTriangleKernel(T *input, int fill_mode_, size_t matrix_n_, size_t count) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
+    size_t idx = 0;
+    if (fill_mode_ == 0) {  // UPPER half of each matrix must be all 0
+      for (size_t row = 0; row < matrix_n_; row++) {
+        for (size_t col = row + 1; col < matrix_n_; col++) {
+          idx = i * matrix_n_ * matrix_n_ + row * matrix_n_ + col;
+          if (static_cast<float>(input[idx]) != 0) {
+            dev_error_res = false;
+            return;
+          }
+        }
+      }
+    } else if (fill_mode_ == 1) {  // LOWER half of each matrix must be all 0
+      for (size_t row = 0; row < matrix_n_; row++) {
+        for (size_t col = 0; col < row; col++) {
+          idx = i * matrix_n_ * matrix_n_ + row * matrix_n_ + col;
+          if (static_cast<float>(input[idx]) != 0) {
+            dev_error_res = false;
+            return;
+          }
+        }
+      }
+    } else {
+      dev_error_res = false;
+      return;
+    }
+  }
+  return;
+}
+
+template <typename T>
+bool CheckTriangle(T *input, int fill_mode_, size_t matrix_n_, size_t count, cudaStream_t cuda_stream) {
+  // Reset the flag to "valid" before the check; the kernel only ever clears it on error,
+  // so concurrent threads cannot race between writing true and false.
+  bool host_error_res = true;
+  cudaMemcpyToSymbol(dev_error_res, &host_error_res, sizeof(bool));
+  CheckTriangleKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, fill_mode_, matrix_n_, count);
+  cudaMemcpyFromSymbol(&host_error_res, dev_error_res, sizeof(bool));
+  return host_error_res;
+}
+
+template void DetTriangle<float>(float *input, float *output, size_t matrix_n_, size_t count,
+                                 cudaStream_t cuda_stream);
+template void DetTriangle<half>(half *input, half *output, size_t matrix_n_, size_t count, cudaStream_t cuda_stream);
+template bool CheckTriangle<float>(float *input, int fill_mode_, size_t matrix_n_, size_t count,
+                                   cudaStream_t cuda_stream);
+template bool CheckTriangle<half>(half *input, int fill_mode_, size_t matrix_n_, size_t count,
+                                  cudaStream_t cuda_stream);
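The kernel above assigns one thread per matrix in the batch and relies on the fact that the determinant of a triangular matrix is the product of its diagonal entries (the inner loop walks the flattened matrix with stride matrix_n_ + 1). A minimal NumPy sketch of the same computation, for illustration only; the helper name is not part of the patch:

```python
import numpy as np

def det_triangle_ref(batch):
    # batch: (count, n, n) triangular matrices -> (count,) determinants,
    # each taken as the product of the matrix's diagonal entries.
    return np.prod(np.diagonal(batch, axis1=-2, axis2=-1), axis=-1)

a = np.array([[[1, 0, 0], [2, 3, 0], [4, 5, 6]]], dtype=np.float32)
assert np.allclose(det_triangle_ref(a), np.linalg.det(a))  # both give [18.]
```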
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/determinant_triangle_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/determinant_triangle_impl.cuh
new file mode 100644
index 0000000000..b017fe602f
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/determinant_triangle_impl.cuh
@@ -0,0 +1,27 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_DETERMINANT_TRIANGLE_IMPL_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_DETERMINANT_TRIANGLE_IMPL_H_
+
+#include <cuda_runtime.h>
+#include "runtime/device/gpu/cuda_common.h"
+
+template <typename T>
+void DetTriangle(T *input, T *output, size_t matrix_n_, size_t count, cudaStream_t cuda_stream);
+template <typename T>
+bool CheckTriangle(T *input, int fill_mode_, size_t matrix_n_, size_t count, cudaStream_t cuda_stream);
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_DETERMINANT_TRIANGLE_IMPL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/determinant_triangle_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/determinant_triangle_gpu_kernel.cc
new file mode 100644
index 0000000000..75246be40d
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/determinant_triangle_gpu_kernel.cc
@@ -0,0 +1,26 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/math/determinant_triangle_gpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_ONE(DetTriangle, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+                      DetTriangleGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(DetTriangle, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+                      DetTriangleGpuKernel, half)
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/determinant_triangle_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/determinant_triangle_gpu_kernel.h
new file mode 100644
index 0000000000..4b94939a44
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/determinant_triangle_gpu_kernel.h
@@ -0,0 +1,112 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_DETERMINANT_TRIANGLE_GPU_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_DETERMINANT_TRIANGLE_GPU_KERNEL_H_
+
+#include <cuda_runtime_api.h>
+#include <vector>
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/determinant_triangle_impl.cuh"
+
+namespace mindspore {
+namespace kernel {
+template <typename T>
+class DetTriangleGpuKernel : public GpuKernel {
+ public:
+  DetTriangleGpuKernel() : input_size_(sizeof(T)), output_size_(sizeof(T)), matrix_n_(0), fill_mode_(0) {}
+  ~DetTriangleGpuKernel() override = default;
+
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
+    VARIABLE_NOT_USED(workspace);
+    T *input_addr = GetDeviceAddress<T>(inputs, 0);
+    T *output_addr = GetDeviceAddress<T>(outputs, 0);
+
+    if (!CheckTriangle(input_addr, fill_mode_, matrix_n_, outputs[0]->size / sizeof(T),
+                       reinterpret_cast<cudaStream_t>(stream_ptr))) {
+      if (fill_mode_ == 0) {
+        MS_LOG(ERROR) << "The elements in the upper half of the matrices should be all 0.";
+      } else if (fill_mode_ == 1) {
+        MS_LOG(ERROR) << "The elements in the lower half of the matrices should be all 0.";
+      } else {
+        MS_LOG(ERROR) << "The input matrices should be either upper filled or lower filled.";
+      }
+      return false;
+    }
+    DetTriangle(input_addr, output_addr, matrix_n_, outputs[0]->size / sizeof(T),
+                reinterpret_cast<cudaStream_t>(stream_ptr));
+    return true;
+  }
+
+  bool Init(const CNodePtr &kernel_node) override {
+    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+    if (input_num != 1) {
+      MS_LOG(ERROR) << "Input number is " << input_num << ", but DetTriangle needs 1 input.";
+      return false;
+    }
+    size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
+    if (output_num != 1) {
+      MS_LOG(ERROR) << "Output number is " << output_num << ", but DetTriangle needs 1 output.";
+      return false;
+    }
+    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+    for (size_t i = 0; i < input_shape.size(); i++) {
+      input_size_ *= input_shape[i];
+    }
+    matrix_n_ = input_shape[input_shape.size() - 1];
+    auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
+    for (size_t i = 0; i < output_shape.size(); i++) {
+      output_size_ *= output_shape[i];
+    }
+    if (output_size_ != input_size_ / matrix_n_ / matrix_n_) {
+      MS_LOG(ERROR) << "The output shape is wrong.";
+      return false;
+    }
+    if (input_shape[input_shape.size() - 2] != input_shape[input_shape.size() - 1]) {
+      MS_LOG(ERROR) << "The input matrices should be square.";
+      return false;
+    }
+    fill_mode_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("fill_mode"));
+    InitSizeLists();
+    return true;
+  }
+
+ protected:
+  void InitSizeLists() override {
+    input_size_list_.push_back(input_size_);
+    output_size_list_.push_back(output_size_);
+  }
+
+ private:
+  size_t input_size_;
+  size_t output_size_;
+  size_t matrix_n_;
+  int fill_mode_;
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_DETERMINANT_TRIANGLE_GPU_KERNEL_H_
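To make the size bookkeeping in Init concrete: input_size_ and output_size_ are byte counts seeded with sizeof(T), so a float32 input of shape (4, 3, 3) gives input_size_ = 4 * 3 * 3 * 4 = 144 bytes and matrix_n_ = 3, while the expected output shape (4,) gives output_size_ = 16 bytes, which satisfies output_size_ == input_size_ / matrix_n_ / matrix_n_. A small Python sketch of that check, illustrative only and not part of the patch:

```python
import numpy as np

def sizes_consistent(input_shape, output_shape, dtype=np.float32):
    # Mirrors the byte-size consistency check in DetTriangleGpuKernel::Init.
    itemsize = np.dtype(dtype).itemsize
    input_size = itemsize * int(np.prod(input_shape))    # input_size_ in bytes
    output_size = itemsize * int(np.prod(output_shape))  # output_size_ in bytes
    matrix_n = input_shape[-1]
    return output_size == input_size // matrix_n // matrix_n

assert sizes_consistent((4, 3, 3), (4,))  # 144 / 3 / 3 == 16
```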
diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py
index 898c6d3b61..ded5b6dd18 100644
--- a/mindspore/ops/operations/__init__.py
+++ b/mindspore/ops/operations/__init__.py
@@ -87,7 +87,7 @@ from .other_ops import (Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, Popul
 from ._thor_ops import (CusBatchMatMul, CusCholeskyTrsm, CusFusedAbsMax1, CusImg2Col, CusMatMulCubeDenseLeft,
                         CusMatMulCubeFraczRightMul, CusMatMulCube, CusMatrixCombine, CusTranspose02314,
                         CusMatMulCubeDenseRight,
-                        CusMatMulCubeFraczLeftCast, Im2Col, UpdateThorGradient, Cholesky)
+                        CusMatMulCubeFraczLeftCast, Im2Col, UpdateThorGradient, Cholesky, DetTriangle)
 from .sparse_ops import SparseToDense
 from ._cache_ops import CacheSwapHashmap, SearchCacheIdx, CacheSwapTable, UpdateCache, MapCacheIdx
diff --git a/mindspore/ops/operations/_thor_ops.py b/mindspore/ops/operations/_thor_ops.py
index 9cca988955..a757a37fca 100644
--- a/mindspore/ops/operations/_thor_ops.py
+++ b/mindspore/ops/operations/_thor_ops.py
@@ -636,3 +636,22 @@ class Cholesky(PrimitiveWithInfer):
     def infer_dtype(self, x1_dtype):
         validator.check_tensor_type_same({'x1_dtype': x1_dtype}, [mstype.float32], self.name)
         return x1_dtype
+
+
+class DetTriangle(PrimitiveWithInfer):
+    """
+    Calculate the determinant of triangular matrices.
+    """
+
+    @prim_attr_register
+    def __init__(self, fill_mode=0):
+        self.init_prim_io_names(inputs=['x1'], outputs=['y'])
+        self.fill_mode = fill_mode
+        self.add_prim_attr('fill_mode', self.fill_mode)
+
+    def infer_shape(self, x1_shape):
+        out_shape = x1_shape[:-2]
+        return out_shape
+
+    def infer_dtype(self, x1_dtype):
+        validator.check_tensor_type_same({'x1_dtype': x1_dtype}, [mstype.float32], self.name)
+        return x1_dtype
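A note on the fill_mode attribute, as implemented by CheckTriangleKernel: fill_mode=0 requires every entry above the main diagonal to be zero (a lower-triangular input), fill_mode=1 requires every entry below it to be zero, and any other value makes the kernel report an error. A small NumPy illustration of the two conditions, not part of the patch:

```python
import numpy as np

x = np.array([[1, 0, 0],
              [2, 3, 0],
              [4, 5, 6]], dtype=np.float32)

# fill_mode=0: the strict upper triangle must be all zero (lower-triangular input).
assert np.all(np.triu(x, k=1) == 0)
# fill_mode=1 would instead require np.all(np.tril(x, k=-1) == 0).
```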
diff --git a/tests/st/ops/gpu/test_determinant_triangle.py b/tests/st/ops/gpu/test_determinant_triangle.py
new file mode 100644
index 0000000000..3faffb5785
--- /dev/null
+++ b/tests/st/ops/gpu/test_determinant_triangle.py
@@ -0,0 +1,44 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.ops import operations as P
+from mindspore.common import dtype as mstype
+
+context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+
+
+class Net(nn.Cell):
+    def __init__(self, fill_mode=0):
+        super(Net, self).__init__()
+        self.det_triangle = P.DetTriangle(fill_mode=fill_mode)
+
+    def construct(self, x):
+        return self.det_triangle(x)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_net_1D():
+    fill_mode = 0
+    input_x = np.array([[1, 0, 0], [2, 3, 0], [4, 5, 6]]).astype(np.float32)
+    net = Net(fill_mode=fill_mode)
+    tx = Tensor(input_x, mstype.float32)
+    output = net(tx)
+    assert output == 18
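Since infer_shape drops the last two dimensions, the op also accepts a batch of matrices, e.g. an input of shape (2, 3, 3) produces an output of shape (2,). A possible follow-up test in the same style, sketched under the assumption that it lives in the same file and reuses the Net class above; the test name and tolerance are illustrative:

```python
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_net_batched():
    # Two lower-triangular 3x3 matrices; determinants checked against NumPy.
    input_x = np.array([[[2, 0, 0], [1, 3, 0], [4, 5, 6]],
                        [[1, 0, 0], [7, 2, 0], [9, 8, 5]]], dtype=np.float32)
    net = Net(fill_mode=0)
    output = net(Tensor(input_x, mstype.float32))
    assert np.allclose(output.asnumpy(), np.linalg.det(input_x), rtol=1e-5)
```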