!14183 Support SparseTensorDenseMatmul for CPU

From: @xuguoyang5566
Merge ref: pull/14183/MERGE
Commit: efb53fb9c0
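The change adds a CPU kernel for SparseTensorDenseMatmul, which multiplies a 2-D sparse matrix given in COO form (indices, values, dense_shape) by an ordinary dense matrix. As a rough, self-contained reference for the semantics the kernel below is meant to implement, here is a minimal NumPy sketch; the helper name coo_to_dense is purely illustrative and not part of MindSpore:

import numpy as np

def coo_to_dense(indices, values, dense_shape):
    """Scatter COO (row, col) -> value entries into a dense 2-D array."""
    dense = np.zeros(dense_shape, dtype=values.dtype)
    for (row, col), val in zip(indices, values):
        dense[row, col] += val
    return dense

# Same data as the test at the end of this change.
indices = np.array([[0, 1], [1, 1]])
values = np.array([5, 5], dtype=np.float32)
b = np.ones((3, 3), dtype=np.float32)
expected = coo_to_dense(indices, values, (3, 3)) @ b  # rows 0 and 1 become [5, 5, 5]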
					
@@ -0,0 +1,75 @@
/**
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/kernel_compiler/cpu/sparse_tensor_dense_matmul_cpu_kernel.h"

namespace mindspore {
namespace kernel {
template <typename I, typename T>
void SparseTensorDenseMatmulCPUKernel<I, T>::InitKernel(const CNodePtr &kernel_node) {
  output_size_ = 1;
  auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
  for (auto &dim : output_shape) {
    output_size_ *= dim;
  }

  aValues_size_ = 1;
  auto aValues_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
  for (auto &dim : aValues_shape) {
    aValues_size_ *= dim;
  }

  b_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 3);
  output_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
}

template <typename I, typename T>
bool SparseTensorDenseMatmulCPUKernel<I, T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
                                                    const std::vector<kernel::AddressPtr> & /*workspace*/,
                                                    const std::vector<kernel::AddressPtr> &outputs) {
  auto a_indices = reinterpret_cast<I *>(inputs[0]->addr);
  auto a_values = reinterpret_cast<T *>(inputs[1]->addr);
  auto b = reinterpret_cast<T *>(inputs[3]->addr);
  auto out = reinterpret_cast<T *>(outputs[0]->addr);

  // output_size_ counts elements, so the zeroed region is output_size_ * sizeof(T) bytes.
  memset(out, 0, output_size_ * sizeof(T));

  const size_t nnz = aValues_size_;
  const size_t rhs_right = b_shape_[1];
  const size_t lhs_right = b_shape_[0];
  const size_t out_right = output_shape_[1];

  for (size_t i = 0; i < nnz; ++i) {
    // Each nonzero of the sparse operand is described by a (row m, col k) index pair.
    const size_t m = a_indices[i * 2];
    const size_t k = a_indices[i * 2 + 1];

    if (k >= lhs_right) {
      MS_LOG(ERROR) << "Invalid value: k: " << k << ", lhs_right: " << lhs_right;
      return false;
    }
    if (m >= output_shape_[0]) {
      MS_LOG(ERROR) << "Invalid value: m: " << m << ", output_shape: " << output_shape_[0];
      return false;
    }

    // Row-major accumulation: out[m, n] += values[i] * b[k, n].
    for (size_t n = 0; n < rhs_right; ++n) {
      const T b_value = b[k * rhs_right + n];
      out[m * out_right + n] += a_values[i] * b_value;
    }
  }
  return true;
}
}  // namespace kernel
}  // namespace mindspore
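For readers less used to the flattened row-major indexing in Launch, the loop above is equivalent to the following scatter-accumulate, sketched here in plain NumPy; the function name sparse_dense_matmul_ref and its arguments are illustrative only, not part of the MindSpore code:

import numpy as np

def sparse_dense_matmul_ref(indices, values, out_shape, b):
    """Per nonzero (m, k): out[m, :] += values[i] * b[k, :], matching the kernel loop."""
    out = np.zeros(out_shape, dtype=b.dtype)
    for (m, k), v in zip(indices, values):
        if k >= b.shape[0] or m >= out_shape[0]:
            raise ValueError("sparse index out of range")
        out[m, :] += v * b[k, :]
    return out

# e.g. sparse_dense_matmul_ref(np.array([[0, 1], [1, 1]]), np.array([5.0, 5.0]), (3, 3), np.ones((3, 3)))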
@@ -0,0 +1,243 @@
/**
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_TENSOR_DENSE_MATMUL_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_TENSOR_DENSE_MATMUL_CPU_KERNEL_H_

#include <vector>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"

namespace mindspore {
namespace kernel {
template <typename I, typename T>
class SparseTensorDenseMatmulCPUKernel : public CPUKernel {
 public:
  SparseTensorDenseMatmulCPUKernel() = default;
  ~SparseTensorDenseMatmulCPUKernel() override = default;

  void InitKernel(const CNodePtr &kernel_node) override;

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs) override;

 private:
  std::vector<size_t> output_shape_;
  std::vector<size_t> b_shape_;
  size_t output_size_{0};
  size_t aValues_size_{0};
};
MS_REG_CPU_KERNEL_T_S(SparseTensorDenseMatmul,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeInt32)
                        .AddInputAttr(kNumberTypeBool)
                        .AddInputAttr(kNumberTypeInt32)
                        .AddInputAttr(kNumberTypeBool)
                        .AddOutputAttr(kNumberTypeBool),
                      SparseTensorDenseMatmulCPUKernel, int32_t, bool);

MS_REG_CPU_KERNEL_T_S(SparseTensorDenseMatmul,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeInt32)
                        .AddInputAttr(kNumberTypeUInt8)
                        .AddInputAttr(kNumberTypeInt32)
                        .AddInputAttr(kNumberTypeUInt8)
                        .AddOutputAttr(kNumberTypeUInt8),
                      SparseTensorDenseMatmulCPUKernel, int32_t, uint8_t);

MS_REG_CPU_KERNEL_T_S(SparseTensorDenseMatmul,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeInt32)
                        .AddInputAttr(kNumberTypeUInt16)
                        .AddInputAttr(kNumberTypeInt32)
                        .AddInputAttr(kNumberTypeUInt16)
                        .AddOutputAttr(kNumberTypeUInt16),
                      SparseTensorDenseMatmulCPUKernel, int32_t, uint16_t);

MS_REG_CPU_KERNEL_T_S(SparseTensorDenseMatmul,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeInt32)
                        .AddInputAttr(kNumberTypeUInt32)
                        .AddInputAttr(kNumberTypeInt32)
                        .AddInputAttr(kNumberTypeUInt32)
                        .AddOutputAttr(kNumberTypeUInt32),
                      SparseTensorDenseMatmulCPUKernel, int32_t, uint32_t);

MS_REG_CPU_KERNEL_T_S(SparseTensorDenseMatmul,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeInt32)
                        .AddInputAttr(kNumberTypeUInt64)
                        .AddInputAttr(kNumberTypeInt32)
                        .AddInputAttr(kNumberTypeUInt64)
                        .AddOutputAttr(kNumberTypeUInt64),
                      SparseTensorDenseMatmulCPUKernel, int32_t, uint64_t);

MS_REG_CPU_KERNEL_T_S(SparseTensorDenseMatmul,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeInt32)
                        .AddInputAttr(kNumberTypeInt8)
                        .AddInputAttr(kNumberTypeInt32)
                        .AddInputAttr(kNumberTypeInt8)
                        .AddOutputAttr(kNumberTypeInt8),
                      SparseTensorDenseMatmulCPUKernel, int32_t, int8_t);

MS_REG_CPU_KERNEL_T_S(SparseTensorDenseMatmul,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeInt32)
                        .AddInputAttr(kNumberTypeInt16)
                        .AddInputAttr(kNumberTypeInt32)
                        .AddInputAttr(kNumberTypeInt16)
                        .AddOutputAttr(kNumberTypeInt16),
                      SparseTensorDenseMatmulCPUKernel, int32_t, int16_t);

MS_REG_CPU_KERNEL_T_S(SparseTensorDenseMatmul,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeInt32)
                        .AddInputAttr(kNumberTypeInt32)
                        .AddInputAttr(kNumberTypeInt32)
                        .AddInputAttr(kNumberTypeInt32)
                        .AddOutputAttr(kNumberTypeInt32),
                      SparseTensorDenseMatmulCPUKernel, int32_t, int32_t);

MS_REG_CPU_KERNEL_T_S(SparseTensorDenseMatmul,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeInt32)
                        .AddInputAttr(kNumberTypeInt64)
                        .AddInputAttr(kNumberTypeInt32)
                        .AddInputAttr(kNumberTypeInt64)
                        .AddOutputAttr(kNumberTypeInt64),
                      SparseTensorDenseMatmulCPUKernel, int32_t, int64_t);

MS_REG_CPU_KERNEL_T_S(SparseTensorDenseMatmul,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeInt32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeInt32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddOutputAttr(kNumberTypeFloat32),
                      SparseTensorDenseMatmulCPUKernel, int32_t, float);

MS_REG_CPU_KERNEL_T_S(SparseTensorDenseMatmul,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeInt32)
                        .AddInputAttr(kNumberTypeFloat64)
                        .AddInputAttr(kNumberTypeInt32)
                        .AddInputAttr(kNumberTypeFloat64)
                        .AddOutputAttr(kNumberTypeFloat64),
                      SparseTensorDenseMatmulCPUKernel, int32_t, double);

MS_REG_CPU_KERNEL_T_S(SparseTensorDenseMatmul,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeInt64)
                        .AddInputAttr(kNumberTypeBool)
                        .AddInputAttr(kNumberTypeInt32)
                        .AddInputAttr(kNumberTypeBool)
                        .AddOutputAttr(kNumberTypeBool),
                      SparseTensorDenseMatmulCPUKernel, int64_t, bool);

MS_REG_CPU_KERNEL_T_S(SparseTensorDenseMatmul,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeInt64)
                        .AddInputAttr(kNumberTypeUInt8)
                        .AddInputAttr(kNumberTypeInt32)
                        .AddInputAttr(kNumberTypeUInt8)
                        .AddOutputAttr(kNumberTypeUInt8),
                      SparseTensorDenseMatmulCPUKernel, int64_t, uint8_t);

MS_REG_CPU_KERNEL_T_S(SparseTensorDenseMatmul,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeInt64)
                        .AddInputAttr(kNumberTypeUInt16)
                        .AddInputAttr(kNumberTypeInt32)
                        .AddInputAttr(kNumberTypeUInt16)
                        .AddOutputAttr(kNumberTypeUInt16),
                      SparseTensorDenseMatmulCPUKernel, int64_t, uint16_t);

MS_REG_CPU_KERNEL_T_S(SparseTensorDenseMatmul,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeInt64)
                        .AddInputAttr(kNumberTypeUInt32)
                        .AddInputAttr(kNumberTypeInt32)
                        .AddInputAttr(kNumberTypeUInt32)
                        .AddOutputAttr(kNumberTypeUInt32),
                      SparseTensorDenseMatmulCPUKernel, int64_t, uint32_t);

MS_REG_CPU_KERNEL_T_S(SparseTensorDenseMatmul,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeInt64)
                        .AddInputAttr(kNumberTypeUInt64)
                        .AddInputAttr(kNumberTypeInt32)
                        .AddInputAttr(kNumberTypeUInt64)
                        .AddOutputAttr(kNumberTypeUInt64),
                      SparseTensorDenseMatmulCPUKernel, int64_t, uint64_t);

MS_REG_CPU_KERNEL_T_S(SparseTensorDenseMatmul,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeInt64)
                        .AddInputAttr(kNumberTypeInt8)
                        .AddInputAttr(kNumberTypeInt32)
                        .AddInputAttr(kNumberTypeInt8)
                        .AddOutputAttr(kNumberTypeInt8),
                      SparseTensorDenseMatmulCPUKernel, int64_t, int8_t);

MS_REG_CPU_KERNEL_T_S(SparseTensorDenseMatmul,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeInt64)
                        .AddInputAttr(kNumberTypeInt16)
                        .AddInputAttr(kNumberTypeInt32)
                        .AddInputAttr(kNumberTypeInt16)
                        .AddOutputAttr(kNumberTypeInt16),
                      SparseTensorDenseMatmulCPUKernel, int64_t, int16_t);

MS_REG_CPU_KERNEL_T_S(SparseTensorDenseMatmul,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeInt64)
                        .AddInputAttr(kNumberTypeInt32)
                        .AddInputAttr(kNumberTypeInt32)
                        .AddInputAttr(kNumberTypeInt32)
                        .AddOutputAttr(kNumberTypeInt32),
                      SparseTensorDenseMatmulCPUKernel, int64_t, int32_t);

MS_REG_CPU_KERNEL_T_S(SparseTensorDenseMatmul,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeInt64)
                        .AddInputAttr(kNumberTypeInt64)
                        .AddInputAttr(kNumberTypeInt32)
                        .AddInputAttr(kNumberTypeInt64)
                        .AddOutputAttr(kNumberTypeInt64),
                      SparseTensorDenseMatmulCPUKernel, int64_t, int64_t);

MS_REG_CPU_KERNEL_T_S(SparseTensorDenseMatmul,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeInt64)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeInt32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddOutputAttr(kNumberTypeFloat32),
                      SparseTensorDenseMatmulCPUKernel, int64_t, float);

MS_REG_CPU_KERNEL_T_S(SparseTensorDenseMatmul,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeInt64)
                        .AddInputAttr(kNumberTypeFloat64)
                        .AddInputAttr(kNumberTypeInt32)
                        .AddInputAttr(kNumberTypeFloat64)
                        .AddOutputAttr(kNumberTypeFloat64),
                      SparseTensorDenseMatmulCPUKernel, int64_t, double);

}  // namespace kernel
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_TENSOR_DENSE_MATMUL_CPU_KERNEL_H_
@@ -0,0 +1,53 @@
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import mindspore as ms
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import SparseTensor

context.set_context(mode=context.GRAPH_MODE, device_target="CPU")

class NetSparseDenseMatmul(nn.Cell):
    def __init__(self):
        super(NetSparseDenseMatmul, self).__init__()
        self.matmul = nn.SparseTensorDenseMatmul()

    def construct(self, indices, values, dense_shape, dt):
        return self.matmul(indices, values, dense_shape, dt)


class NetSparseTensor(nn.Cell):
    def __init__(self, dense_shape):
        super(NetSparseTensor, self).__init__()
        self.dense_shape = dense_shape

    def construct(self, indices, values):
        x = SparseTensor(indices, values, self.dense_shape)
        return x.values, x.indices, x.dense_shape


def test_sparse_tensor_dense_matmul():
    indices = Tensor([[0, 1], [1, 1]])
    values = Tensor([5, 5], dtype=ms.float32)
    dense_shape = (3, 3)
    # Dense equivalent of the sparse input: value 5 at positions (0, 1) and (1, 1).
    sp_matrix = np.array([[0, 5, 0], [0, 5, 0], [0, 0, 0]], dtype=np.float32)
    ds_matrix = np.array([[1, 1, 1], [1, 1, 1], [1, 1, 1]], dtype=np.float32)
    net = NetSparseDenseMatmul()

    out_ms = net(indices, values, dense_shape, Tensor(ds_matrix))
    out_np = np.matmul(sp_matrix, ds_matrix)
    error = np.ones(shape=ds_matrix.shape) * 10e-6
    diff = np.abs(out_ms.asnumpy() - out_np)
    assert np.all(diff < error)