diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/hsigmoid_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/hsigmoid_cpu_kernel.cc index d67cbcc027..4f7a0c6641 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/hsigmoid_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/hsigmoid_cpu_kernel.cc @@ -20,54 +20,38 @@ namespace mindspore { namespace kernel { -void HSigmoidCPUKernel::InitKernel(const CNodePtr &kernel_node) { +template +void HSigmoidCPUKernel::InitKernel(const CNodePtr &kernel_node) { CheckParam(kernel_node); x_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - dtype_ = AnfAlgo ::GetPrevNodeOutputDeviceDataType(kernel_node, 0); - if (dtype_ == kTypeUnknown) { - dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0); - } for (const uint64_t &d : x_shape_) { tensor_size_ *= d; } - - launch_map_[kNumberTypeInt8] = &HSigmoidCPUKernel::LaunchKernel; - launch_map_[kNumberTypeInt16] = &HSigmoidCPUKernel::LaunchKernel; - launch_map_[kNumberTypeInt32] = &HSigmoidCPUKernel::LaunchKernel; - launch_map_[kNumberTypeInt64] = &HSigmoidCPUKernel::LaunchKernel; - launch_map_[kNumberTypeFloat32] = &HSigmoidCPUKernel::LaunchKernel; - - auto iter = launch_map_.find(dtype_); - if (iter != launch_map_.end()) { - launch_func_ = iter->second; - } else { - MS_LOG(EXCEPTION) << "Input data type: " << dtype_ << "is not supported for HSigmoid kernel on CPU."; - } -} - -bool HSigmoidCPUKernel::Launch(const std::vector &inputs, - const std::vector & /*workspace*/, - const std::vector &outputs) { - launch_func_(this, inputs, outputs); - return true; } template -void HSigmoidCPUKernel::LaunchKernel(const std::vector &inputs, const std::vector &outputs) { +bool HSigmoidCPUKernel::Launch(const std::vector &inputs, + const std::vector & /*workspace*/, + const std::vector &outputs) { auto x = reinterpret_cast(inputs[0]->addr); auto y = reinterpret_cast(outputs[0]->addr); - for (uint64_t i = 0; i < tensor_size_; ++i) { - if (x[i] <= -3) { - y[i] = 0; - } else if (x[i] >= 3) { - y[i] = 1; - } else { - y[i] = (x[i] + 3) / 6; + auto task = [&](size_t start, size_t end) { + for (uint64_t i = start; i < end; ++i) { + if (x[i] <= -3) { + y[i] = 0; + } else if (x[i] >= 3) { + y[i] = 1; + } else { + y[i] = (x[i] + 3) / 6; + } } - } + }; + CPUKernelUtils::ParallelFor(task, tensor_size_); + return true; } -void HSigmoidCPUKernel::CheckParam(const CNodePtr &kernel_node) { +template +void HSigmoidCPUKernel::CheckParam(const CNodePtr &kernel_node) { size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); if (input_num != 1) { MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but HSigmoidCPUKernel needs 1 input."; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/hsigmoid_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/hsigmoid_cpu_kernel.h index 1185c14121..ea5f68114c 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/hsigmoid_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/hsigmoid_cpu_kernel.h @@ -24,6 +24,7 @@ namespace mindspore { namespace kernel { +template class HSigmoidCPUKernel : public CPUKernel { public: HSigmoidCPUKernel() = default; @@ -34,34 +35,26 @@ class HSigmoidCPUKernel : public CPUKernel { bool Launch(const std::vector &inputs, const std::vector &workspace, const std::vector &outputs) override; - template - void LaunchKernel(const std::vector &inputs, const std::vector &outputs); - private: void CheckParam(const CNodePtr &kernel_node); std::vector x_shape_; - TypeId dtype_{kTypeUnknown}; - using TypeKernel = std::function &inputs, - const std::vector &outputs)>; - std::unordered_map launch_map_; - TypeKernel launch_func_; uint64_t tensor_size_ = 1; }; -MS_REG_CPU_KERNEL(HSigmoid, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), - HSigmoidCPUKernel); +MS_REG_CPU_KERNEL_T(HSigmoid, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), + HSigmoidCPUKernel, int8_t); -MS_REG_CPU_KERNEL(HSigmoid, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16), - HSigmoidCPUKernel); +MS_REG_CPU_KERNEL_T(HSigmoid, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16), + HSigmoidCPUKernel, int16_t); -MS_REG_CPU_KERNEL(HSigmoid, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), - HSigmoidCPUKernel); +MS_REG_CPU_KERNEL_T(HSigmoid, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + HSigmoidCPUKernel, int); -MS_REG_CPU_KERNEL(HSigmoid, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), - HSigmoidCPUKernel); +MS_REG_CPU_KERNEL_T(HSigmoid, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), + HSigmoidCPUKernel, int64_t); -MS_REG_CPU_KERNEL(HSigmoid, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - HSigmoidCPUKernel); +MS_REG_CPU_KERNEL_T(HSigmoid, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + HSigmoidCPUKernel, float); } // namespace kernel } // namespace mindspore #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TILE_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/hsigmoid_grad_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/hsigmoid_grad_cpu_kernel.cc index 7877f20eb1..27facaa385 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/hsigmoid_grad_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/hsigmoid_grad_cpu_kernel.cc @@ -20,54 +20,37 @@ namespace mindspore { namespace kernel { -void HSigmoidGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { +template +void HSigmoidGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { CheckParam(kernel_node); x_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); - dtype_ = AnfAlgo ::GetPrevNodeOutputDeviceDataType(kernel_node, 0); - if (dtype_ == kTypeUnknown) { - dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0); - } for (const uint64_t &d : x_shape_) { tensor_size_ *= d; } - - launch_map_[kNumberTypeInt8] = &HSigmoidGradCPUKernel::LaunchKernel; - launch_map_[kNumberTypeInt16] = &HSigmoidGradCPUKernel::LaunchKernel; - launch_map_[kNumberTypeInt32] = &HSigmoidGradCPUKernel::LaunchKernel; - launch_map_[kNumberTypeInt64] = &HSigmoidGradCPUKernel::LaunchKernel; - launch_map_[kNumberTypeFloat32] = &HSigmoidGradCPUKernel::LaunchKernel; - - auto iter = launch_map_.find(dtype_); - if (iter != launch_map_.end()) { - launch_func_ = iter->second; - } else { - MS_LOG(EXCEPTION) << "Input data type: " << dtype_ << "is not supported for HSigmoidGrad kernel on CPU."; - } -} - -bool HSigmoidGradCPUKernel::Launch(const std::vector &inputs, - const std::vector & /*workspace*/, - const std::vector &outputs) { - launch_func_(this, inputs, outputs); - return true; } template -void HSigmoidGradCPUKernel::LaunchKernel(const std::vector &inputs, - const std::vector &outputs) { +bool HSigmoidGradCPUKernel::Launch(const std::vector &inputs, + const std::vector & /*workspace*/, + const std::vector &outputs) { auto dy = reinterpret_cast(inputs[0]->addr); auto x = reinterpret_cast(inputs[1]->addr); auto out = reinterpret_cast(outputs[0]->addr); - for (uint64_t i = 0; i < tensor_size_; ++i) { - if (x[i] <= -3 || x[i] >= 3) { - out[i] = 0; - } else { - out[i] = dy[i] / 6; + auto task = [&](size_t start, size_t end) { + for (uint64_t i = start; i < end; ++i) { + if (x[i] <= -3 || x[i] >= 3) { + out[i] = 0; + } else { + out[i] = dy[i] / 6; + } } - } + }; + CPUKernelUtils::ParallelFor(task, tensor_size_); + return true; } -void HSigmoidGradCPUKernel::CheckParam(const CNodePtr &kernel_node) { +template +void HSigmoidGradCPUKernel::CheckParam(const CNodePtr &kernel_node) { size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); if (input_num != 2) { MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but HSigmoidGradCPUKernel needs 2 input."; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/hsigmoid_grad_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/hsigmoid_grad_cpu_kernel.h index 9d434232a6..dbb73cc8cc 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/hsigmoid_grad_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/hsigmoid_grad_cpu_kernel.h @@ -24,6 +24,7 @@ namespace mindspore { namespace kernel { +template class HSigmoidGradCPUKernel : public CPUKernel { public: HSigmoidGradCPUKernel() = default; @@ -34,43 +35,35 @@ class HSigmoidGradCPUKernel : public CPUKernel { bool Launch(const std::vector &inputs, const std::vector &workspace, const std::vector &outputs) override; - template - void LaunchKernel(const std::vector &inputs, const std::vector &outputs); - private: void CheckParam(const CNodePtr &kernel_node); std::vector x_shape_; - TypeId dtype_{kTypeUnknown}; - using TypeKernel = std::function &inputs, - const std::vector &outputs)>; - std::unordered_map launch_map_; - TypeKernel launch_func_; uint64_t tensor_size_ = 1; }; -MS_REG_CPU_KERNEL( +MS_REG_CPU_KERNEL_T( HSigmoidGrad, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), - HSigmoidGradCPUKernel); + HSigmoidGradCPUKernel, int8_t); -MS_REG_CPU_KERNEL( +MS_REG_CPU_KERNEL_T( HSigmoidGrad, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16), - HSigmoidGradCPUKernel); + HSigmoidGradCPUKernel, int16_t); -MS_REG_CPU_KERNEL( +MS_REG_CPU_KERNEL_T( HSigmoidGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), - HSigmoidGradCPUKernel); + HSigmoidGradCPUKernel, int); -MS_REG_CPU_KERNEL( +MS_REG_CPU_KERNEL_T( HSigmoidGrad, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), - HSigmoidGradCPUKernel); + HSigmoidGradCPUKernel, int64_t); -MS_REG_CPU_KERNEL( +MS_REG_CPU_KERNEL_T( HSigmoidGrad, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - HSigmoidGradCPUKernel); + HSigmoidGradCPUKernel, float); } // namespace kernel } // namespace mindspore #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TILE_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/hswish_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/hswish_cpu_kernel.cc index 5a18e7abcb..cdfacd0421 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/hswish_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/hswish_cpu_kernel.cc @@ -20,54 +20,38 @@ namespace mindspore { namespace kernel { -void HSwishCPUKernel::InitKernel(const CNodePtr &kernel_node) { +template +void HSwishCPUKernel::InitKernel(const CNodePtr &kernel_node) { CheckParam(kernel_node); x_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - dtype_ = AnfAlgo ::GetPrevNodeOutputDeviceDataType(kernel_node, 0); - if (dtype_ == kTypeUnknown) { - dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0); - } for (const uint64_t &d : x_shape_) { tensor_size_ *= d; } - - launch_map_[kNumberTypeInt8] = &HSwishCPUKernel::LaunchKernel; - launch_map_[kNumberTypeInt16] = &HSwishCPUKernel::LaunchKernel; - launch_map_[kNumberTypeInt32] = &HSwishCPUKernel::LaunchKernel; - launch_map_[kNumberTypeInt64] = &HSwishCPUKernel::LaunchKernel; - launch_map_[kNumberTypeFloat32] = &HSwishCPUKernel::LaunchKernel; - - auto iter = launch_map_.find(dtype_); - if (iter != launch_map_.end()) { - launch_func_ = iter->second; - } else { - MS_LOG(EXCEPTION) << "Input data type: " << dtype_ << "is not supported for HSwish kernel on CPU."; - } -} - -bool HSwishCPUKernel::Launch(const std::vector &inputs, - const std::vector & /*workspace*/, - const std::vector &outputs) { - launch_func_(this, inputs, outputs); - return true; } template -void HSwishCPUKernel::LaunchKernel(const std::vector &inputs, const std::vector &outputs) { +bool HSwishCPUKernel::Launch(const std::vector &inputs, + const std::vector & /*workspace*/, + const std::vector &outputs) { auto x = reinterpret_cast(inputs[0]->addr); auto y = reinterpret_cast(outputs[0]->addr); - for (uint64_t i = 0; i < tensor_size_; ++i) { - if (x[i] <= -3) { - y[i] = 0; - } else if (x[i] >= 3) { - y[i] = x[i]; - } else { - y[i] = x[i] * (x[i] + 3) / 6; + auto task = [&](size_t start, size_t end) { + for (uint64_t i = start; i < end; ++i) { + if (x[i] <= -3) { + y[i] = 0; + } else if (x[i] >= 3) { + y[i] = x[i]; + } else { + y[i] = x[i] * (x[i] + 3) / 6; + } } - } + }; + CPUKernelUtils::ParallelFor(task, tensor_size_); + return true; } -void HSwishCPUKernel::CheckParam(const CNodePtr &kernel_node) { +template +void HSwishCPUKernel::CheckParam(const CNodePtr &kernel_node) { size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); if (input_num != 1) { MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but HSwishCPUKernel needs 1 input."; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/hswish_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/hswish_cpu_kernel.h index 71a16efe12..c8cc353890 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/hswish_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/hswish_cpu_kernel.h @@ -24,6 +24,7 @@ namespace mindspore { namespace kernel { +template class HSwishCPUKernel : public CPUKernel { public: HSwishCPUKernel() = default; @@ -34,30 +35,26 @@ class HSwishCPUKernel : public CPUKernel { bool Launch(const std::vector &inputs, const std::vector &workspace, const std::vector &outputs) override; - template - void LaunchKernel(const std::vector &inputs, const std::vector &outputs); - private: void CheckParam(const CNodePtr &kernel_node); std::vector x_shape_; - TypeId dtype_{kTypeUnknown}; - using TypeKernel = std::function &inputs, - const std::vector &outputs)>; - std::unordered_map launch_map_; - TypeKernel launch_func_; uint64_t tensor_size_ = 1; }; -MS_REG_CPU_KERNEL(HSwish, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), HSwishCPUKernel); +MS_REG_CPU_KERNEL_T(HSwish, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), HSwishCPUKernel, + int8_t); -MS_REG_CPU_KERNEL(HSwish, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16), HSwishCPUKernel); +MS_REG_CPU_KERNEL_T(HSwish, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16), + HSwishCPUKernel, int16_t); -MS_REG_CPU_KERNEL(HSwish, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), HSwishCPUKernel); +MS_REG_CPU_KERNEL_T(HSwish, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + HSwishCPUKernel, int); -MS_REG_CPU_KERNEL(HSwish, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), HSwishCPUKernel); +MS_REG_CPU_KERNEL_T(HSwish, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), + HSwishCPUKernel, int64_t); -MS_REG_CPU_KERNEL(HSwish, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - HSwishCPUKernel); +MS_REG_CPU_KERNEL_T(HSwish, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + HSwishCPUKernel, float); } // namespace kernel } // namespace mindspore #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TILE_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/hswish_grad_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/hswish_grad_cpu_kernel.cc index 939757b8b7..794722413e 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/hswish_grad_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/hswish_grad_cpu_kernel.cc @@ -20,55 +20,39 @@ namespace mindspore { namespace kernel { -void HSwishGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { +template +void HSwishGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { CheckParam(kernel_node); x_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); - dtype_ = AnfAlgo ::GetPrevNodeOutputDeviceDataType(kernel_node, 0); - if (dtype_ == kTypeUnknown) { - dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0); - } for (const uint64_t &d : x_shape_) { tensor_size_ *= d; } - - launch_map_[kNumberTypeInt8] = &HSwishGradCPUKernel::LaunchKernel; - launch_map_[kNumberTypeInt16] = &HSwishGradCPUKernel::LaunchKernel; - launch_map_[kNumberTypeInt32] = &HSwishGradCPUKernel::LaunchKernel; - launch_map_[kNumberTypeInt64] = &HSwishGradCPUKernel::LaunchKernel; - launch_map_[kNumberTypeFloat32] = &HSwishGradCPUKernel::LaunchKernel; - - auto iter = launch_map_.find(dtype_); - if (iter != launch_map_.end()) { - launch_func_ = iter->second; - } else { - MS_LOG(EXCEPTION) << "Input data type: " << dtype_ << "is not supported for HSwishGrad kernel on CPU."; - } -} - -bool HSwishGradCPUKernel::Launch(const std::vector &inputs, - const std::vector & /*workspace*/, - const std::vector &outputs) { - launch_func_(this, inputs, outputs); - return true; } template -void HSwishGradCPUKernel::LaunchKernel(const std::vector &inputs, const std::vector &outputs) { +bool HSwishGradCPUKernel::Launch(const std::vector &inputs, + const std::vector & /*workspace*/, + const std::vector &outputs) { auto dy = reinterpret_cast(inputs[0]->addr); auto x = reinterpret_cast(inputs[1]->addr); auto out = reinterpret_cast(outputs[0]->addr); - for (uint64_t i = 0; i < tensor_size_; ++i) { - if (x[i] <= -3) { - out[i] = 0; - } else if (x[i] >= 3) { - out[i] = dy[i]; - } else { - out[i] = dy[i] * (2 * x[i] + 3) / 6; + auto task = [&](size_t start, size_t end) { + for (uint64_t i = start; i < end; ++i) { + if (x[i] <= -3) { + out[i] = 0; + } else if (x[i] >= 3) { + out[i] = dy[i]; + } else { + out[i] = dy[i] * (2 * x[i] + 3) / 6; + } } - } + }; + CPUKernelUtils::ParallelFor(task, tensor_size_); + return true; } -void HSwishGradCPUKernel::CheckParam(const CNodePtr &kernel_node) { +template +void HSwishGradCPUKernel::CheckParam(const CNodePtr &kernel_node) { size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); if (input_num != 2) { MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but HSwishGradCPUKernel needs 2 input."; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/hswish_grad_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/hswish_grad_cpu_kernel.h index 6d163cbe1e..3baff359ae 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/hswish_grad_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/hswish_grad_cpu_kernel.h @@ -24,6 +24,7 @@ namespace mindspore { namespace kernel { +template class HSwishGradCPUKernel : public CPUKernel { public: HSwishGradCPUKernel() = default; @@ -34,43 +35,35 @@ class HSwishGradCPUKernel : public CPUKernel { bool Launch(const std::vector &inputs, const std::vector &workspace, const std::vector &outputs) override; - template - void LaunchKernel(const std::vector &inputs, const std::vector &outputs); - private: void CheckParam(const CNodePtr &kernel_node); std::vector x_shape_; - TypeId dtype_{kTypeUnknown}; - using TypeKernel = std::function &inputs, - const std::vector &outputs)>; - std::unordered_map launch_map_; - TypeKernel launch_func_; uint64_t tensor_size_ = 1; }; -MS_REG_CPU_KERNEL( +MS_REG_CPU_KERNEL_T( HSwishGrad, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), - HSwishGradCPUKernel); + HSwishGradCPUKernel, int8_t); -MS_REG_CPU_KERNEL( +MS_REG_CPU_KERNEL_T( HSwishGrad, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16), - HSwishGradCPUKernel); + HSwishGradCPUKernel, int16_t); -MS_REG_CPU_KERNEL( +MS_REG_CPU_KERNEL_T( HSwishGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), - HSwishGradCPUKernel); + HSwishGradCPUKernel, int); -MS_REG_CPU_KERNEL( +MS_REG_CPU_KERNEL_T( HSwishGrad, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), - HSwishGradCPUKernel); + HSwishGradCPUKernel, int64_t); -MS_REG_CPU_KERNEL( +MS_REG_CPU_KERNEL_T( HSwishGrad, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - HSwishGradCPUKernel); + HSwishGradCPUKernel, float); } // namespace kernel } // namespace mindspore #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TILE_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/smooth_l1_loss_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/smooth_l1_loss_cpu_kernel.cc index b47a764740..1a72aabae9 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/smooth_l1_loss_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/smooth_l1_loss_cpu_kernel.cc @@ -19,50 +19,45 @@ namespace mindspore { namespace kernel { -void SmoothL1LossCPUKernel::InitKernel(const CNodePtr &kernel_node) { +template +void SmoothL1LossCPUKernel::InitKernel(const CNodePtr &kernel_node) { beta_ = AnfAlgo::GetNodeAttr(kernel_node, "beta"); CheckParam(kernel_node); - dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0); std::vector x_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); for (const uint64_t &d : x_shape) { tensor_size_ *= d; } } -bool SmoothL1LossCPUKernel::Launch(const std::vector &inputs, - const std::vector & /*workspace*/, - const std::vector &outputs) { - if (dtype_ == kNumberTypeFloat16) { - LaunchKernel(inputs, outputs); - } else if (dtype_ == kNumberTypeFloat32) { - LaunchKernel(inputs, outputs); - } - return true; -} - template -void SmoothL1LossCPUKernel::LaunchKernel(const std::vector &inputs, - const std::vector &outputs) { +bool SmoothL1LossCPUKernel::Launch(const std::vector &inputs, + const std::vector & /*workspace*/, + const std::vector &outputs) { auto predict_addr = reinterpret_cast(inputs[0]->addr); auto target_addr = reinterpret_cast(inputs[1]->addr); auto result_addr = reinterpret_cast(outputs[0]->addr); T zero = (T)0.0; T half = (T)0.5; T beta = (T)beta_; - for (uint64_t i = 0; i < tensor_size_; ++i) { - T diff = predict_addr[i] - target_addr[i]; - if (diff < zero) { - diff = -diff; + auto task = [&](size_t start, size_t end) { + for (uint64_t i = start; i < end; ++i) { + T diff = predict_addr[i] - target_addr[i]; + if (diff < zero) { + diff = -diff; + } + if (diff < beta) { + result_addr[i] = half * diff * diff / beta; + } else { + result_addr[i] = diff - (half * beta); + } } - if (diff < beta) { - result_addr[i] = half * diff * diff / beta; - } else { - result_addr[i] = diff - (half * beta); - } - } + }; + CPUKernelUtils::ParallelFor(task, tensor_size_); + return true; } -void SmoothL1LossCPUKernel::CheckParam(const CNodePtr &kernel_node) { +template +void SmoothL1LossCPUKernel::CheckParam(const CNodePtr &kernel_node) { size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); if (input_num != 2) { MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but SmoothL1LossCPUKernel needs 2 input."; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/smooth_l1_loss_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/smooth_l1_loss_cpu_kernel.h index 321322a3de..4c2c34156b 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/smooth_l1_loss_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/smooth_l1_loss_cpu_kernel.h @@ -24,6 +24,7 @@ namespace mindspore { namespace kernel { +template class SmoothL1LossCPUKernel : public CPUKernel { public: SmoothL1LossCPUKernel() = default; @@ -34,9 +35,6 @@ class SmoothL1LossCPUKernel : public CPUKernel { bool Launch(const std::vector &inputs, const std::vector &workspace, const std::vector &outputs) override; - template - void LaunchKernel(const std::vector &inputs, const std::vector &outputs); - private: void CheckParam(const CNodePtr &kernel_node); float beta_ = 1.0; @@ -44,15 +42,15 @@ class SmoothL1LossCPUKernel : public CPUKernel { uint64_t tensor_size_ = 1; }; -MS_REG_CPU_KERNEL( +MS_REG_CPU_KERNEL_T( SmoothL1Loss, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - SmoothL1LossCPUKernel); + SmoothL1LossCPUKernel, float16); -MS_REG_CPU_KERNEL( +MS_REG_CPU_KERNEL_T( SmoothL1Loss, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - SmoothL1LossCPUKernel); + SmoothL1LossCPUKernel, float); } // namespace kernel } // namespace mindspore #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SMOOTH_L1_LOSS_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/smooth_l1_loss_grad_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/smooth_l1_loss_grad_cpu_kernel.cc index cdb64f4d5c..1d7ca7ff31 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/smooth_l1_loss_grad_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/smooth_l1_loss_grad_cpu_kernel.cc @@ -19,30 +19,20 @@ namespace mindspore { namespace kernel { -void SmoothL1LossGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { +template +void SmoothL1LossGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { beta_ = AnfAlgo::GetNodeAttr(kernel_node, "beta"); CheckParam(kernel_node); - dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0); std::vector x_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); for (const uint64_t &d : x_shape) { tensor_size_ *= d; } } -bool SmoothL1LossGradCPUKernel::Launch(const std::vector &inputs, - const std::vector & /*workspace*/, - const std::vector &outputs) { - if (dtype_ == kNumberTypeFloat16) { - LaunchKernel(inputs, outputs); - } else if (dtype_ == kNumberTypeFloat32) { - LaunchKernel(inputs, outputs); - } - return true; -} - template -void SmoothL1LossGradCPUKernel::LaunchKernel(const std::vector &inputs, - const std::vector &outputs) { +bool SmoothL1LossGradCPUKernel::Launch(const std::vector &inputs, + const std::vector & /*workspace*/, + const std::vector &outputs) { auto predict_addr = reinterpret_cast(inputs[0]->addr); auto target_addr = reinterpret_cast(inputs[1]->addr); auto dloss_addr = reinterpret_cast(inputs[2]->addr); @@ -58,9 +48,11 @@ void SmoothL1LossGradCPUKernel::LaunchKernel(const std::vector &inpu result_addr[i] = (diff / beta) * dloss_addr[i]; } } + return true; } -void SmoothL1LossGradCPUKernel::CheckParam(const CNodePtr &kernel_node) { +template +void SmoothL1LossGradCPUKernel::CheckParam(const CNodePtr &kernel_node) { size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); if (input_num != 3) { MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but SmoothL1LossGradCPUKernel needs 3 input."; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/smooth_l1_loss_grad_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/smooth_l1_loss_grad_cpu_kernel.h index a703e33b6e..bd016ba836 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/smooth_l1_loss_grad_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/smooth_l1_loss_grad_cpu_kernel.h @@ -24,6 +24,7 @@ namespace mindspore { namespace kernel { +template class SmoothL1LossGradCPUKernel : public CPUKernel { public: SmoothL1LossGradCPUKernel() = default; @@ -34,31 +35,27 @@ class SmoothL1LossGradCPUKernel : public CPUKernel { bool Launch(const std::vector &inputs, const std::vector &workspace, const std::vector &outputs) override; - template - void LaunchKernel(const std::vector &inputs, const std::vector &outputs); - private: void CheckParam(const CNodePtr &kernel_node); float beta_ = 1.0; - TypeId dtype_{kTypeUnknown}; uint64_t tensor_size_ = 1; }; -MS_REG_CPU_KERNEL(SmoothL1LossGrad, - KernelAttr() - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16), - SmoothL1LossGradCPUKernel); - -MS_REG_CPU_KERNEL(SmoothL1LossGrad, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - SmoothL1LossGradCPUKernel); +MS_REG_CPU_KERNEL_T(SmoothL1LossGrad, + KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16), + SmoothL1LossGradCPUKernel, float16); + +MS_REG_CPU_KERNEL_T(SmoothL1LossGrad, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + SmoothL1LossGradCPUKernel, float); } // namespace kernel } // namespace mindspore #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SMOOTH_L1_LOSS_GRAD_CPU_KERNEL_H_ diff --git a/tests/st/ops/cpu/test_smooth_l1_loss_grad_op.py b/tests/st/ops/cpu/test_smooth_l1_loss_grad_op.py deleted file mode 100644 index f72dbe4590..0000000000 --- a/tests/st/ops/cpu/test_smooth_l1_loss_grad_op.py +++ /dev/null @@ -1,61 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -import numpy as np -import pytest - -import mindspore.context as context -import mindspore.nn as nn -from mindspore import Tensor -from mindspore.ops import operations as P -from mindspore.ops.composite import GradOperation - -context.set_context(mode=context.GRAPH_MODE, device_target="CPU") - - -class Net(nn.Cell): - def __init__(self, sigma=1.0): - super(Net, self).__init__() - self.SmoothL1Loss = P.SmoothL1Loss(sigma) - - def construct(self, pred, gt): - return self.SmoothL1Loss(pred, gt) - - -class Grad(nn.Cell): - def __init__(self, network): - super(Grad, self).__init__() - self.grad = GradOperation(get_all=True, sens_param=True) - self.network = network - - def construct(self, pred, gt, dout): - return self.grad(self.network)(pred, gt, dout) - - -@pytest.mark.level0 -@pytest.mark.platform_x86_cpu -@pytest.mark.env_onecard -def test_net(): - pred = np.random.randn(2, 4).astype(np.float32) - gt = np.random.randn(2, 4).astype(np.float32) - dout = np.random.randn(2, 4).astype(np.float32) - smooth_l1_loss_grad = Grad(Net()) - output = smooth_l1_loss_grad(Tensor(pred), Tensor(gt), Tensor(dout)) - print("------------- input ---------------") - print("predict:\n", pred) - print("grount truth:\n", gt) - print("dout:\n", dout) - print("------------- output ---------------") - print("predict grad:\n", output[0].asnumpy()) diff --git a/tests/st/ops/cpu/test_smooth_l1_loss_op.py b/tests/st/ops/cpu/test_smooth_l1_loss_op.py deleted file mode 100644 index f0fe298ff7..0000000000 --- a/tests/st/ops/cpu/test_smooth_l1_loss_op.py +++ /dev/null @@ -1,48 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -import numpy as np -import pytest - -import mindspore.context as context -import mindspore.nn as nn -from mindspore import Tensor -from mindspore.ops import operations as P - -context.set_context(mode=context.GRAPH_MODE, device_target="CPU") - - -class Net(nn.Cell): - def __init__(self, sigma=1.0): - super(Net, self).__init__() - self.SmoothL1Loss = P.SmoothL1Loss(sigma) - - def construct(self, pred, gt): - return self.SmoothL1Loss(pred, gt) - - -@pytest.mark.level0 -@pytest.mark.platform_x86_cpu -@pytest.mark.env_onecard -def test_net(): - pred = np.random.randn(2, 4).astype(np.float32) - gt = np.random.randn(2, 4).astype(np.float32) - smooth_l1_loss = Net() - loss = smooth_l1_loss(Tensor(pred), Tensor(gt)) - print("------------- input ---------------") - print("predict:\n", pred) - print("grount truth:\n", gt) - print("------------- output ---------------") - print("loss:\n", loss.asnumpy()) diff --git a/tests/st/ops/cpu/test_smoothl1loss_op.py b/tests/st/ops/cpu/test_smoothl1loss_op.py new file mode 100644 index 0000000000..3c6c0f70c3 --- /dev/null +++ b/tests/st/ops/cpu/test_smoothl1loss_op.py @@ -0,0 +1,119 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import composite as C + +def smoothl1loss(beta): + np.random.seed(42) + prediction = np.random.randn(20).astype(np.float32) + target = np.random.randn(20).astype(np.float32) + + net = nn.SmoothL1Loss(beta) + return net(Tensor(prediction), Tensor(target)) + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_smoothl1loss(): + context.set_context(mode=context.GRAPH_MODE, device_target="CPU", save_graphs=True) + + epsilon = 1e-6 + + beta = 1.0 + loss = smoothl1loss(beta) + expect = [0.46941718, 0.00382918, 0.16829303, 2.447778, 0.04812113, 0.05953304, + 2.2302065, 0.07672881, 0.00860204, 0.34798968, 0.00956192, 1.818008, + 0.03262977, 0.36599946, 2.047463, 0.2168481, 0.7216947, 1.7739174, + 0.08826803, 1.109165] + diff = np.absolute(loss.asnumpy() - np.array(expect)) + assert(diff < epsilon).all() + + beta = 1 / 9 + loss = smoothl1loss(beta) + expect = [0.9133791, 0.03446258, 0.5246048, 2.8922224, 0.2546738, 0.289504, + 2.674651, 0.33618113, 0.07560876, 0.7786982, 0.08273339, 2.2624524, + 0.19990394, 0.8000138, 2.4919074, 0.6030006, 1.1661391, 2.2183619, + 0.3646064, 1.5536094] + diff = np.absolute(loss.asnumpy() - np.array(expect)) + assert(diff < epsilon).all() + + +class Grad(nn.Cell): + def __init__(self, network): + super(Grad, self).__init__() + self.grad = C.GradOperation(get_all=True, sens_param=True) + self.network = network + + def construct(self, x1, x2, sens): + gout = self.grad(self.network)(x1, x2, sens) + return gout + + +def smoothl1loss_grad(beta): + np.random.seed(42) + prediction = np.random.randn(20).astype(np.float32) + target = np.random.randn(20).astype(np.float32) + sens = np.random.randn(20).astype(np.float32) + + net = nn.SmoothL1Loss(beta) + grad = Grad(net) + return grad(Tensor(prediction), Tensor(target), Tensor(sens)) + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_smoothl1loss_grad(): + context.set_context(mode=context.GRAPH_MODE, device_target="CPU", save_graphs=True) + + epsilon = 1e-6 + + beta = 1.0 + dx = smoothl1loss_grad(beta) + dx1_expect = [-0.71552587, 0.01499678, -0.06709455, -0.30110368, -0.45868093, + 0.24838912, -0.46063876, 0.41411355, 0.04507046, -1.4708229, + 0.04481723, 0.38508227, -0.17292616, -0.52333146, -1.0309995, + 0.61330026, 0.83921754, -0.3092124, 0.1391843, -0.9755451] + + dx2_expect = [0.71552587, -0.01499678, 0.06709455, 0.30110368, 0.45868093, + -0.24838912, 0.46063876, -0.41411355, -0.04507046, 1.4708229, + -0.04481723, -0.38508227, 0.17292616, 0.52333146, 1.0309995, + -0.61330026, -0.83921754, 0.3092124, -0.1391843, 0.9755451] + + diff1 = np.absolute(dx[0].asnumpy() - np.array(dx1_expect)) + diff2 = np.absolute(dx[1].asnumpy() - np.array(dx2_expect)) + assert(diff1 < epsilon).all() + assert(diff2 < epsilon).all() + + beta = 1 / 9 + dx = smoothl1loss_grad(beta) + dx1_expect = [-0.73846656, 0.13497104, -0.11564828, -0.30110368, -1.478522, + 0.7198442, -0.46063876, 1.0571222, 0.3436183, -1.7630402, + 0.32408398, 0.38508227, -0.676922, -0.6116763, -1.0309995, + 0.93128014, 0.83921754, -0.3092124, 0.33126342, -0.9755451] + dx2_expect = [0.73846656, -0.13497104, 0.11564828, 0.30110368, 1.478522, + -0.7198442, 0.46063876, -1.0571222, -0.3436183, 1.7630402, + -0.32408398, -0.38508227, 0.676922, 0.6116763, 1.0309995, + -0.93128014, -0.83921754, 0.3092124, -0.33126342, 0.9755451] + + diff1 = np.absolute(dx[0].asnumpy() - np.array(dx1_expect)) + diff2 = np.absolute(dx[1].asnumpy() - np.array(dx2_expect)) + assert(diff1 < epsilon).all() + assert(diff2 < epsilon).all()