From c62baec9a4803f1fd0245f17884a8978dc89beff Mon Sep 17 00:00:00 2001 From: zhaoting Date: Fri, 19 Feb 2021 16:21:53 +0800 Subject: [PATCH] add parallel for some CPU ops --- .../kernel_compiler/cpu/adam_cpu_kernel.cc | 48 +- .../kernel_compiler/cpu/adam_cpu_kernel.h | 2 +- .../cpu/arithmetic_cpu_kernel.cc | 504 +++++++++--------- .../cpu/arithmetic_cpu_kernel.h | 38 +- .../cpu/arithmetic_self_cpu_kernel.cc | 318 +++++------ .../cpu/arithmetic_self_cpu_kernel.h | 1 - .../kernel_compiler/cpu/cast_cpu_kernel.cc | 216 +------- .../kernel_compiler/cpu/cast_cpu_kernel.h | 455 ++++++++++------ .../cpu/layer_norm_cpu_kernel.cc | 53 +- .../cpu/layer_norm_grad_cpu_kernel.cc | 100 ++-- .../cpu/unsorted_segment_sum_cpu_kernel.cc | 37 +- 11 files changed, 907 insertions(+), 865 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/adam_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/adam_cpu_kernel.cc index c5e28c59a8..591a964782 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/adam_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/adam_cpu_kernel.cc @@ -16,7 +16,6 @@ #include "backend/kernel_compiler/cpu/adam_cpu_kernel.h" #include -#include #include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h" #include "runtime/device/cpu/cpu_device_address.h" #include "utils/ms_utils.h" @@ -25,16 +24,19 @@ namespace mindspore { namespace kernel { template void AdamCPUKernel::LaunchAdam(T *var, T *m, T *v, float lr, float beta1, float beta2, float epsilon, const T *gradient, - size_t start, size_t end) { - for (size_t i = start; i < end; i++) { - m[i] += (gradient[i] - m[i]) * (1 - beta1); - v[i] += (gradient[i] * gradient[i] - v[i]) * (1 - beta2); - if (use_nesterov) { - var[i] -= lr * (m[i] * beta1 + (1 - beta1) * gradient[i]) / (std::sqrt(v[i]) + epsilon); - } else { - var[i] -= lr * m[i] / (std::sqrt(v[i]) + epsilon); + size_t size) { + auto task = [&](size_t start, size_t end) { + for (size_t i = start; i < end; i++) { + m[i] += (gradient[i] - m[i]) * (1 - beta1); + v[i] += (gradient[i] * gradient[i] - v[i]) * (1 - beta2); + if (use_nesterov) { + var[i] -= lr * (m[i] * beta1 + (1 - beta1) * gradient[i]) / (std::sqrt(v[i]) + epsilon); + } else { + var[i] -= lr * m[i] / (std::sqrt(v[i]) + epsilon); + } } - } + }; + CPUKernelUtils::ParallelFor(task, size); } void AdamCPUKernel::InitKernel(const CNodePtr &kernel_node) { @@ -84,31 +86,7 @@ bool AdamCPUKernel::Launch(const std::vector &inputs, // multithreading size_t lens = inputs[0]->size > 0 ? static_cast(inputs[0]->size / sizeof(float)) : 1; - auto max_thread_num = std::thread::hardware_concurrency(); - size_t thread_num = lens < 128 * max_thread_num ? std::ceil(lens / 128.0) : max_thread_num; - MS_LOG(INFO) << "Lens=" << lens << "; use thread_num=" << thread_num << "; max_thread_num: " << max_thread_num; - std::vector threads; - if (thread_num < 1) { - MS_LOG(ERROR) << "Invalid value: thread_num " << thread_num; - return false; - } - threads.reserve(thread_num); - size_t start = 0; - size_t once_compute_size = (lens + thread_num - 1) / thread_num; - if (once_compute_size < 1) { - MS_LOG(ERROR) << "Invalid value: once_compute_size " << once_compute_size; - return false; - } - while (start < lens) { - size_t end = (start + once_compute_size) > lens ? lens : (start + once_compute_size); - threads.emplace_back(std::thread(&AdamCPUKernel::LaunchAdam, this, var, m, v, new_lr, beta1, beta2, epsilon, - gradient, start, end)); - start += once_compute_size; - } - for (size_t i = 0; i < threads.size(); ++i) { - threads[i].join(); - } - + LaunchAdam(var, m, v, new_lr, beta1, beta2, epsilon, gradient, lens); return true; } } // namespace kernel diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/adam_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/adam_cpu_kernel.h index 79e42695bf..61a02ee93a 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/adam_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/adam_cpu_kernel.h @@ -29,7 +29,7 @@ class AdamCPUKernel : public CPUKernel { ~AdamCPUKernel() override = default; template void LaunchAdam(T *var, T *m, T *v, float lr, float beta1, float beta2, float epsilon, const T *gradient, - size_t start, size_t end); + size_t size); void InitKernel(const CNodePtr &kernel_node) override; bool Launch(const std::vector &inputs, const std::vector &workspace, diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.cc index 82920c84ee..262fa0024f 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.cc @@ -15,7 +15,6 @@ */ #include #include -#include #include #include "backend/kernel_compiler/cpu/arithmetic_cpu_kernel.h" #include "runtime/device/cpu/cpu_device_address.h" @@ -23,227 +22,285 @@ namespace mindspore { namespace kernel { template -void ArithmeticCPUKernel::AssignAdd(T *input1, const T *input2, T *out, size_t start, size_t end) { - for (size_t i = start; i < end; i++) { - out[i] = input1[i] + input2[i]; - input1[i] = out[i]; - } +void ArithmeticCPUKernel::AssignAdd(T *input1, const T *input2, T *out, size_t size) { + auto task = [&](size_t start, size_t end) { + for (size_t i = start; i < end; i++) { + out[i] = input1[i] + input2[i]; + input1[i] = out[i]; + } + }; + CPUKernelUtils::ParallelFor(task, size); } template -void ArithmeticCPUKernel::Add(const T *input1, const T *input2, T *out, size_t start, size_t end) { - for (size_t i = start; i < end; i++) { - std::vector idx; - GenIndex(i, &idx); - out[i] = input1[idx[0]] + input2[idx[1]]; - } +void ArithmeticCPUKernel::Add(const T *input1, const T *input2, T *out, size_t size) { + auto task = [&](size_t start, size_t end) { + for (size_t i = start; i < end; i++) { + std::vector idx; + GenIndex(i, &idx); + out[i] = input1[idx[0]] + input2[idx[1]]; + } + }; + CPUKernelUtils::ParallelFor(task, size); } template -void ArithmeticCPUKernel::Sub(const T *input1, const T *input2, T *out, size_t start, size_t end) { - for (size_t i = start; i < end; i++) { - std::vector idx; - GenIndex(i, &idx); - out[i] = input1[idx[0]] - input2[idx[1]]; - } +void ArithmeticCPUKernel::Sub(const T *input1, const T *input2, T *out, size_t size) { + auto task = [&](size_t start, size_t end) { + for (size_t i = start; i < end; i++) { + std::vector idx; + GenIndex(i, &idx); + out[i] = input1[idx[0]] - input2[idx[1]]; + } + }; + CPUKernelUtils::ParallelFor(task, size); } template -void ArithmeticCPUKernel::Mul(const T *input1, const T *input2, T *out, size_t start, size_t end) { - for (size_t i = start; i < end; i++) { - std::vector idx; - GenIndex(i, &idx); - out[i] = input1[idx[0]] * input2[idx[1]]; - } +void ArithmeticCPUKernel::Mul(const T *input1, const T *input2, T *out, size_t size) { + auto task = [&](size_t start, size_t end) { + for (size_t i = start; i < end; i++) { + std::vector idx; + GenIndex(i, &idx); + out[i] = input1[idx[0]] * input2[idx[1]]; + } + }; + CPUKernelUtils::ParallelFor(task, size); } template -void ArithmeticCPUKernel::RealDiv(const T *input1, const T *input2, T *out, size_t start, size_t end) { - for (size_t i = start; i < end; i++) { - std::vector idx; - GenIndex(i, &idx); - auto dividend = input1[idx[0]]; - auto divisor = input2[idx[1]]; - if (divisor == 0) { - if (dividend == 0) { - out[i] = std::numeric_limits::quiet_NaN(); +void ArithmeticCPUKernel::RealDiv(const T *input1, const T *input2, T *out, size_t size) { + auto task = [&](size_t start, size_t end) { + for (size_t i = start; i < end; i++) { + std::vector idx; + GenIndex(i, &idx); + auto dividend = input1[idx[0]]; + auto divisor = input2[idx[1]]; + if (divisor == 0) { + if (dividend == 0) { + out[i] = std::numeric_limits::quiet_NaN(); + continue; + } + if (std::numeric_limits::has_infinity) { + out[i] = dividend > 0 ? std::numeric_limits::infinity() : -std::numeric_limits::infinity(); + } else { + out[i] = dividend > 0 ? std::numeric_limits::max() : std::numeric_limits::min(); + } continue; } - if (std::numeric_limits::has_infinity) { - out[i] = dividend > 0 ? std::numeric_limits::infinity() : -std::numeric_limits::infinity(); - } else { - out[i] = dividend > 0 ? std::numeric_limits::max() : std::numeric_limits::min(); - } - continue; + out[i] = dividend / divisor; } - out[i] = dividend / divisor; - } + }; + CPUKernelUtils::ParallelFor(task, size); } template -void ArithmeticCPUKernel::Div(const T *input1, const T *input2, T *out, size_t start, size_t end) { - for (size_t i = start; i < end; i++) { - std::vector idx; - GenIndex(i, &idx); - auto dividend = input1[idx[0]]; - auto divisor = input2[idx[1]]; - if (divisor == 0) { - if (dividend == 0) { - out[i] = std::numeric_limits::quiet_NaN(); +void ArithmeticCPUKernel::Div(const T *input1, const T *input2, T *out, size_t size) { + auto task = [&](size_t start, size_t end) { + for (size_t i = start; i < end; i++) { + std::vector idx; + GenIndex(i, &idx); + auto dividend = input1[idx[0]]; + auto divisor = input2[idx[1]]; + if (divisor == 0) { + if (dividend == 0) { + out[i] = std::numeric_limits::quiet_NaN(); + continue; + } + if (std::numeric_limits::has_infinity) { + out[i] = dividend > 0 ? std::numeric_limits::infinity() : -std::numeric_limits::infinity(); + } else { + out[i] = dividend > 0 ? std::numeric_limits::max() : std::numeric_limits::min(); + } continue; } - if (std::numeric_limits::has_infinity) { - out[i] = dividend > 0 ? std::numeric_limits::infinity() : -std::numeric_limits::infinity(); - } else { - out[i] = dividend > 0 ? std::numeric_limits::max() : std::numeric_limits::min(); - } - continue; + out[i] = dividend / divisor; } - out[i] = dividend / divisor; - } + }; + CPUKernelUtils::ParallelFor(task, size); } template -void ArithmeticCPUKernel::FloorDiv(const T *input1, const T *input2, T *out, size_t start, size_t end) { - for (size_t i = start; i < end; i++) { - std::vector idx; - GenIndex(i, &idx); - auto dividend = input1[idx[0]]; - auto divisor = input2[idx[1]]; - if (divisor == 0) { - if (dividend == 0) { - out[i] = std::numeric_limits::quiet_NaN(); +void ArithmeticCPUKernel::FloorDiv(const T *input1, const T *input2, T *out, size_t size) { + auto task = [&](size_t start, size_t end) { + for (size_t i = start; i < end; i++) { + std::vector idx; + GenIndex(i, &idx); + auto dividend = input1[idx[0]]; + auto divisor = input2[idx[1]]; + if (divisor == 0) { + if (dividend == 0) { + out[i] = std::numeric_limits::quiet_NaN(); + continue; + } + if (std::numeric_limits::has_infinity) { + out[i] = dividend > 0 ? std::numeric_limits::infinity() : -std::numeric_limits::infinity(); + } else { + out[i] = dividend > 0 ? std::numeric_limits::max() : std::numeric_limits::min(); + } continue; } - if (std::numeric_limits::has_infinity) { - out[i] = dividend > 0 ? std::numeric_limits::infinity() : -std::numeric_limits::infinity(); - } else { - out[i] = dividend > 0 ? std::numeric_limits::max() : std::numeric_limits::min(); - } - continue; + out[i] = floor(dividend / divisor); } - out[i] = floor(dividend / divisor); - } + }; + CPUKernelUtils::ParallelFor(task, size); } template -void ArithmeticCPUKernel::Mod(const T *input1, const T *input2, T *out, size_t start, size_t end) { - for (size_t i = start; i < end; i++) { - std::vector idx; - GenIndex(i, &idx); - auto x = static_cast(input1[idx[0]]); - auto y = static_cast(input2[idx[1]]); - auto data_div = x / y; - auto data_div_min = data_div < 0.0 ? data_div : 0.0; - auto data_div_max = data_div > 0.0 ? data_div : 0.0; - auto data_div_max_floor = floor(data_div_max); - auto data_div_min_ceil = ceil(data_div_min); - auto data_div_res = data_div_max_floor + data_div_min_ceil; - out[i] = static_cast(x - data_div_res * y); - } +void ArithmeticCPUKernel::Mod(const T *input1, const T *input2, T *out, size_t size) { + auto task = [&](size_t start, size_t end) { + for (size_t i = start; i < end; i++) { + std::vector idx; + GenIndex(i, &idx); + auto x = static_cast(input1[idx[0]]); + auto y = static_cast(input2[idx[1]]); + auto data_div = x / y; + auto data_div_min = data_div < 0.0 ? data_div : 0.0; + auto data_div_max = data_div > 0.0 ? data_div : 0.0; + auto data_div_max_floor = floor(data_div_max); + auto data_div_min_ceil = ceil(data_div_min); + auto data_div_res = data_div_max_floor + data_div_min_ceil; + out[i] = static_cast(x - data_div_res * y); + } + }; + CPUKernelUtils::ParallelFor(task, size); } template -void ArithmeticCPUKernel::Pow(const T *input1, const T *input2, T *out, size_t start, size_t end) { - for (size_t i = start; i < end; i++) { - std::vector idx; - GenIndex(i, &idx); - auto x = static_cast(input1[idx[0]]); - auto y = static_cast(input2[idx[1]]); - out[i] = static_cast(std::pow(x, y)); - } +void ArithmeticCPUKernel::Pow(const T *input1, const T *input2, T *out, size_t size) { + auto task = [&](size_t start, size_t end) { + for (size_t i = start; i < end; i++) { + std::vector idx; + GenIndex(i, &idx); + auto x = static_cast(input1[idx[0]]); + auto y = static_cast(input2[idx[1]]); + out[i] = static_cast(std::pow(x, y)); + } + }; + CPUKernelUtils::ParallelFor(task, size); } template -void ArithmeticCPUKernel::Less(const T *input1, const T *input2, bool *out, size_t start, size_t end) { - for (size_t i = start; i < end; i++) { - std::vector idx; - GenIndex(i, &idx); - out[i] = input1[idx[0]] < input2[idx[1]]; - } +void ArithmeticCPUKernel::Less(const T *input1, const T *input2, bool *out, size_t size) { + auto task = [&](size_t start, size_t end) { + for (size_t i = start; i < end; i++) { + std::vector idx; + GenIndex(i, &idx); + out[i] = input1[idx[0]] < input2[idx[1]]; + } + }; + CPUKernelUtils::ParallelFor(task, size); } template -void ArithmeticCPUKernel::Equal(const T *input1, const T *input2, bool *out, size_t start, size_t end) { - for (size_t i = start; i < end; i++) { - std::vector idx; - GenIndex(i, &idx); - out[i] = input1[idx[0]] == input2[idx[1]]; - } +void ArithmeticCPUKernel::Equal(const T *input1, const T *input2, bool *out, size_t size) { + auto task = [&](size_t start, size_t end) { + for (size_t i = start; i < end; i++) { + std::vector idx; + GenIndex(i, &idx); + out[i] = input1[idx[0]] == input2[idx[1]]; + } + }; + CPUKernelUtils::ParallelFor(task, size); } template -void ArithmeticCPUKernel::NotEqual(const T *input1, const T *input2, bool *out, size_t start, size_t end) { - for (size_t i = start; i < end; i++) { - std::vector idx; - GenIndex(i, &idx); - out[i] = input1[idx[0]] != input2[idx[1]]; - } +void ArithmeticCPUKernel::NotEqual(const T *input1, const T *input2, bool *out, size_t size) { + auto task = [&](size_t start, size_t end) { + for (size_t i = start; i < end; i++) { + std::vector idx; + GenIndex(i, &idx); + out[i] = input1[idx[0]] != input2[idx[1]]; + } + }; + CPUKernelUtils::ParallelFor(task, size); } template -void ArithmeticCPUKernel::LogicalAnd(const T *input1, const T *input2, bool *out, size_t start, size_t end) { - for (size_t i = start; i < end; i++) { - std::vector idx; - GenIndex(i, &idx); - out[i] = input1[idx[0]] && input2[idx[1]]; - } +void ArithmeticCPUKernel::LogicalAnd(const T *input1, const T *input2, bool *out, size_t size) { + auto task = [&](size_t start, size_t end) { + for (size_t i = start; i < end; i++) { + std::vector idx; + GenIndex(i, &idx); + out[i] = input1[idx[0]] && input2[idx[1]]; + } + }; + CPUKernelUtils::ParallelFor(task, size); } template -void ArithmeticCPUKernel::LogicalOr(const T *input1, const T *input2, bool *out, size_t start, size_t end) { - for (size_t i = start; i < end; i++) { - std::vector idx; - GenIndex(i, &idx); - out[i] = input1[idx[0]] || input2[idx[1]]; - } +void ArithmeticCPUKernel::LogicalOr(const T *input1, const T *input2, bool *out, size_t size) { + auto task = [&](size_t start, size_t end) { + for (size_t i = start; i < end; i++) { + std::vector idx; + GenIndex(i, &idx); + out[i] = input1[idx[0]] || input2[idx[1]]; + } + }; + CPUKernelUtils::ParallelFor(task, size); } template -void ArithmeticCPUKernel::SquaredDifference(const T *input1, const T *input2, T *out, size_t start, size_t end) { - for (size_t i = start; i < end; i++) { - std::vector idx; - GenIndex(i, &idx); - T diff = input1[idx[0]] - input2[idx[1]]; - out[i] = diff * diff; - } +void ArithmeticCPUKernel::SquaredDifference(const T *input1, const T *input2, T *out, size_t size) { + auto task = [&](size_t start, size_t end) { + for (size_t i = start; i < end; i++) { + std::vector idx; + GenIndex(i, &idx); + T diff = input1[idx[0]] - input2[idx[1]]; + out[i] = diff * diff; + } + }; + CPUKernelUtils::ParallelFor(task, size); } template -void ArithmeticCPUKernel::Greater(const T *input1, const T *input2, bool *out, size_t start, size_t end) { - for (size_t i = start; i < end; i++) { - std::vector idx; - GenIndex(i, &idx); - out[i] = input1[idx[0]] > input2[idx[1]]; - } +void ArithmeticCPUKernel::Greater(const T *input1, const T *input2, bool *out, size_t size) { + auto task = [&](size_t start, size_t end) { + for (size_t i = start; i < end; i++) { + std::vector idx; + GenIndex(i, &idx); + out[i] = input1[idx[0]] > input2[idx[1]]; + } + }; + CPUKernelUtils::ParallelFor(task, size); } template -void ArithmeticCPUKernel::GreaterEqual(const T *input1, const T *input2, bool *out, size_t start, size_t end) { - for (size_t i = start; i < end; i++) { - std::vector idx; - GenIndex(i, &idx); - out[i] = input1[idx[0]] >= input2[idx[1]]; - } +void ArithmeticCPUKernel::GreaterEqual(const T *input1, const T *input2, bool *out, size_t size) { + auto task = [&](size_t start, size_t end) { + for (size_t i = start; i < end; i++) { + std::vector idx; + GenIndex(i, &idx); + out[i] = input1[idx[0]] >= input2[idx[1]]; + } + }; + CPUKernelUtils::ParallelFor(task, size); } template -void ArithmeticCPUKernel::LessEqual(const T *input1, const T *input2, bool *out, size_t start, size_t end) { - for (size_t i = start; i < end; i++) { - std::vector idx; - GenIndex(i, &idx); - out[i] = input1[idx[0]] <= input2[idx[1]]; - } +void ArithmeticCPUKernel::LessEqual(const T *input1, const T *input2, bool *out, size_t size) { + auto task = [&](size_t start, size_t end) { + for (size_t i = start; i < end; i++) { + std::vector idx; + GenIndex(i, &idx); + out[i] = input1[idx[0]] <= input2[idx[1]]; + } + }; + CPUKernelUtils::ParallelFor(task, size); } template -void ArithmeticCPUKernel::Atan2(const T *input1, const T *input2, T *out, size_t start, size_t end) { - for (size_t i = start; i < end; i++) { - std::vector idx; - GenIndex(i, &idx); - out[i] = atan2(input1[idx[0]], input2[idx[1]]); - } +void ArithmeticCPUKernel::Atan2(const T *input1, const T *input2, T *out, size_t size) { + auto task = [&](size_t start, size_t end) { + for (size_t i = start; i < end; i++) { + std::vector idx; + GenIndex(i, &idx); + out[i] = atan2(input1[idx[0]], input2[idx[1]]); + } + }; + CPUKernelUtils::ParallelFor(task, size); } + static const std::map kArithmeticBinOpTypeMap = { {prim::kPrimGreater->name(), GREATER}, {prim::kPrimAdd->name(), ADD}, @@ -352,49 +409,25 @@ void ArithmeticCPUKernel::LaunchKernelLogic(const std::vector &input T *input1 = reinterpret_cast(inputs[0]->addr); T *input2 = reinterpret_cast(inputs[1]->addr); bool *output = reinterpret_cast(outputs[0]->addr); - size_t lens = outputs[0]->size > 0 ? static_cast(outputs[0]->size / sizeof(bool)) : 1; - auto max_thread_num = std::thread::hardware_concurrency(); - size_t thread_num = lens < 128 * max_thread_num ? std::ceil(lens / 128.0) : max_thread_num; - MS_LOG(INFO) << "Lens=" << lens << "; use thread_num=" << thread_num << "; max_thread_num: " << max_thread_num; - std::vector threads; - if (thread_num < 1) { - MS_LOG(ERROR) << "Invalid value: thread_num " << thread_num; - return; - } - threads.reserve(thread_num); - size_t start = 0; - size_t once_compute_size = (lens + thread_num - 1) / thread_num; - if (once_compute_size < 1) { - MS_LOG(ERROR) << "Invalid value: once_compute_size " << once_compute_size; - return; - } - while (start < lens) { - size_t end = (start + once_compute_size) > lens ? lens : (start + once_compute_size); - if (operate_type_ == LESS) { - threads.emplace_back(std::thread(&ArithmeticCPUKernel::Less, this, input1, input2, output, start, end)); - } else if (operate_type_ == EQUAL) { - threads.emplace_back(std::thread(&ArithmeticCPUKernel::Equal, this, input1, input2, output, start, end)); - } else if (operate_type_ == NOTEQUAL) { - threads.emplace_back(std::thread(&ArithmeticCPUKernel::NotEqual, this, input1, input2, output, start, end)); - } else if (operate_type_ == GREATER) { - threads.emplace_back(std::thread(&ArithmeticCPUKernel::Greater, this, input1, input2, output, start, end)); - } else if (operate_type_ == GREATEREQUAL) { - threads.emplace_back( - std::thread(&ArithmeticCPUKernel::GreaterEqual, this, input1, input2, output, start, end)); - } else if (operate_type_ == LESSEQUAL) { - threads.emplace_back(std::thread(&ArithmeticCPUKernel::LessEqual, this, input1, input2, output, start, end)); - } else if (operate_type_ == LOGICALAND) { - threads.emplace_back(std::thread(&ArithmeticCPUKernel::LogicalAnd, this, input1, input2, output, start, end)); - } else if (operate_type_ == LOGICALOR) { - threads.emplace_back(std::thread(&ArithmeticCPUKernel::LogicalOr, this, input1, input2, output, start, end)); - } else { - MS_LOG(EXCEPTION) << "Not support " << operate_type_; - } - start += once_compute_size; - } - for (size_t i = 0; i < threads.size(); ++i) { - threads[i].join(); + if (operate_type_ == LESS) { + Less(input1, input2, output, lens); + } else if (operate_type_ == EQUAL) { + Equal(input1, input2, output, lens); + } else if (operate_type_ == NOTEQUAL) { + NotEqual(input1, input2, output, lens); + } else if (operate_type_ == GREATER) { + Greater(input1, input2, output, lens); + } else if (operate_type_ == GREATEREQUAL) { + GreaterEqual(input1, input2, output, lens); + } else if (operate_type_ == LESSEQUAL) { + LessEqual(input1, input2, output, lens); + } else if (operate_type_ == LOGICALAND) { + LogicalAnd(input1, input2, output, lens); + } else if (operate_type_ == LOGICALOR) { + LogicalOr(input1, input2, output, lens); + } else { + MS_LOG(EXCEPTION) << "Not support " << operate_type_; } } @@ -409,53 +442,30 @@ void ArithmeticCPUKernel::LaunchKernel(const std::vector &inputs, co T *output = reinterpret_cast(outputs[0]->addr); size_t lens = outputs[0]->size > 0 ? static_cast(outputs[0]->size / sizeof(T)) : 1; - auto max_thread_num = std::thread::hardware_concurrency(); - size_t thread_num = lens < 128 * max_thread_num ? std::ceil(lens / 128.0) : max_thread_num; - MS_LOG(INFO) << "Lens=" << lens << "; use thread_num=" << thread_num << "; max_thread_num: " << max_thread_num; - std::vector threads; - if (thread_num < 1) { - MS_LOG(ERROR) << "Invalid value: thread_num " << thread_num; - return; - } - threads.reserve(thread_num); - size_t start = 0; - size_t once_compute_size = (lens + thread_num - 1) / thread_num; - if (once_compute_size < 1) { - MS_LOG(ERROR) << "Invalid value: once_compute_size " << once_compute_size; - return; - } - while (start < lens) { - size_t end = (start + once_compute_size) > lens ? lens : (start + once_compute_size); - if (operate_type_ == ADD) { - threads.emplace_back(std::thread(&ArithmeticCPUKernel::Add, this, input1, input2, output, start, end)); - } else if (operate_type_ == SUB) { - threads.emplace_back(std::thread(&ArithmeticCPUKernel::Sub, this, input1, input2, output, start, end)); - } else if (operate_type_ == MUL) { - threads.emplace_back(std::thread(&ArithmeticCPUKernel::Mul, this, input1, input2, output, start, end)); - } else if (operate_type_ == REALDIV) { - threads.emplace_back(std::thread(&ArithmeticCPUKernel::RealDiv, this, input1, input2, output, start, end)); - } else if (operate_type_ == DIV) { - threads.emplace_back(std::thread(&ArithmeticCPUKernel::Div, this, input1, input2, output, start, end)); - } else if (operate_type_ == FLOORDIV) { - threads.emplace_back(std::thread(&ArithmeticCPUKernel::FloorDiv, this, input1, input2, output, start, end)); - } else if (operate_type_ == MOD) { - threads.emplace_back(std::thread(&ArithmeticCPUKernel::Mod, this, input1, input2, output, start, end)); - } else if (operate_type_ == POW) { - threads.emplace_back(std::thread(&ArithmeticCPUKernel::Pow, this, input1, input2, output, start, end)); - } else if (operate_type_ == ASSIGNADD) { - threads.emplace_back(std::thread(&ArithmeticCPUKernel::AssignAdd, this, input1, input2, output, start, end)); - } else if (operate_type_ == ATAN2) { - threads.emplace_back(std::thread(&ArithmeticCPUKernel::Atan2, this, input1, input2, output, start, end)); - } else if (operate_type_ == SQUAREDDIFFERENCE) { - threads.emplace_back( - std::thread(&ArithmeticCPUKernel::SquaredDifference, this, input1, input2, output, start, end)); - } else { - MS_LOG(EXCEPTION) << "Not support " << operate_type_; - } - start += once_compute_size; - } - for (size_t i = 0; i < threads.size(); ++i) { - threads[i].join(); + if (operate_type_ == ADD) { + Add(input1, input2, output, lens); + } else if (operate_type_ == SUB) { + Sub(input1, input2, output, lens); + } else if (operate_type_ == MUL) { + Mul(input1, input2, output, lens); + } else if (operate_type_ == REALDIV) { + RealDiv(input1, input2, output, lens); + } else if (operate_type_ == DIV) { + Div(input1, input2, output, lens); + } else if (operate_type_ == FLOORDIV) { + FloorDiv(input1, input2, output, lens); + } else if (operate_type_ == MOD) { + Mod(input1, input2, output, lens); + } else if (operate_type_ == POW) { + Pow(input1, input2, output, lens); + } else if (operate_type_ == ASSIGNADD) { + AssignAdd(input1, input2, output, lens); + } else if (operate_type_ == ATAN2) { + Atan2(input1, input2, output, lens); + } else if (operate_type_ == SQUAREDDIFFERENCE) { + SquaredDifference(input1, input2, output, lens); + } else { + MS_LOG(EXCEPTION) << "Not support " << operate_type_; } } } // namespace kernel diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.h index cc2ab1a4a3..8a707065c0 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.h @@ -40,43 +40,43 @@ class ArithmeticCPUKernel : public CPUKernel { private: void GenIndex(size_t num, std::vector *tmp); template - void Sub(const T *input1, const T *input2, T *out, size_t start, size_t end); + void Sub(const T *input1, const T *input2, T *out, size_t size); template - void Add(const T *input1, const T *input2, T *out, size_t start, size_t end); + void Add(const T *input1, const T *input2, T *out, size_t size); template - void Mul(const T *input1, const T *input2, T *out, size_t start, size_t end); + void Mul(const T *input1, const T *input2, T *out, size_t size); template - void RealDiv(const T *input1, const T *input2, T *out, size_t start, size_t end); + void RealDiv(const T *input1, const T *input2, T *out, size_t size); template - void Div(const T *input1, const T *input2, T *out, size_t start, size_t end); + void Div(const T *input1, const T *input2, T *out, size_t size); template - void FloorDiv(const T *input1, const T *input2, T *out, size_t start, size_t end); + void FloorDiv(const T *input1, const T *input2, T *out, size_t size); template - void Mod(const T *input1, const T *input2, T *out, size_t start, size_t end); + void Mod(const T *input1, const T *input2, T *out, size_t size); template - void Pow(const T *input1, const T *input2, T *out, size_t start, size_t end); + void Pow(const T *input1, const T *input2, T *out, size_t size); template - void AssignAdd(T *input1, const T *input2, T *out, size_t start, size_t end); + void AssignAdd(T *input1, const T *input2, T *out, size_t size); template - void Atan2(const T *input1, const T *input2, T *out, size_t start, size_t end); + void Atan2(const T *input1, const T *input2, T *out, size_t size); template - void Less(const T *input1, const T *input2, bool *out, size_t start, size_t end); + void Less(const T *input1, const T *input2, bool *out, size_t size); template - void Equal(const T *input1, const T *input2, bool *out, size_t start, size_t end); + void Equal(const T *input1, const T *input2, bool *out, size_t size); template - void NotEqual(const T *input1, const T *input2, bool *out, size_t start, size_t end); + void NotEqual(const T *input1, const T *input2, bool *out, size_t size); template - void SquaredDifference(const T *input1, const T *input2, T *out, size_t start, size_t end); + void SquaredDifference(const T *input1, const T *input2, T *out, size_t size); template - void Greater(const T *input1, const T *input2, bool *out, size_t start, size_t end); + void Greater(const T *input1, const T *input2, bool *out, size_t size); template - void GreaterEqual(const T *input1, const T *input2, bool *out, size_t start, size_t end); + void GreaterEqual(const T *input1, const T *input2, bool *out, size_t size); template - void LessEqual(const T *input1, const T *input2, bool *out, size_t start, size_t end); + void LessEqual(const T *input1, const T *input2, bool *out, size_t size); template - void LogicalAnd(const T *input1, const T *input2, bool *out, size_t start, size_t end); + void LogicalAnd(const T *input1, const T *input2, bool *out, size_t size); template - void LogicalOr(const T *input1, const T *input2, bool *out, size_t start, size_t end); + void LogicalOr(const T *input1, const T *input2, bool *out, size_t size); std::vector input_shape0_; std::vector input_shape1_; std::vector input_element_num0_; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.cc index 04bfd4997e..a73682a888 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.cc @@ -24,152 +24,212 @@ namespace mindspore { namespace kernel { namespace { template -void Square(const T *in, T *out, size_t start, size_t end) { - for (size_t i = start; i < end; i++) { - out[i] = in[i] * in[i]; - } +void Square(const T *in, T *out, size_t size) { + auto task = [&](size_t start, size_t end) { + for (size_t i = start; i < end; i++) { + out[i] = in[i] * in[i]; + } + }; + CPUKernelUtils::ParallelFor(task, size); } template -void Sign(const T *in, T *out, size_t start, size_t end) { - for (size_t i = start; i < end; i++) { - if (in[i] < 0) { - out[i] = -1; - } else if (in[i] > 0) { - out[i] = 1; - } else { - out[i] = 0; +void Sign(const T *in, T *out, size_t size) { + auto task = [&](size_t start, size_t end) { + for (size_t i = start; i < end; i++) { + if (in[i] < 0) { + out[i] = -1; + } else if (in[i] > 0) { + out[i] = 1; + } else { + out[i] = 0; + } } - } + }; + CPUKernelUtils::ParallelFor(task, size); } template -void Neg(const T *in, T *out, size_t start, size_t end) { - for (size_t i = start; i < end; i++) { - out[i] = -in[i]; - } +void Neg(const T *in, T *out, size_t size) { + auto task = [&](size_t start, size_t end) { + for (size_t i = start; i < end; i++) { + out[i] = -in[i]; + } + }; + CPUKernelUtils::ParallelFor(task, size); } template -void LogicalNot(const T *in, T *out, size_t start, size_t end) { - for (size_t i = start; i < end; i++) { - out[i] = !in[i]; - } +void LogicalNot(const T *in, T *out, size_t size) { + auto task = [&](size_t start, size_t end) { + for (size_t i = start; i < end; i++) { + out[i] = !in[i]; + } + }; + CPUKernelUtils::ParallelFor(task, size); } template -void OnesLike(const T *in, T *out, size_t start, size_t end) { - for (size_t i = start; i < end; i++) { - out[i] = static_cast(1); - } +void OnesLike(const T *in, T *out, size_t size) { + auto task = [&](size_t start, size_t end) { + for (size_t i = start; i < end; i++) { + out[i] = static_cast(1); + } + }; + CPUKernelUtils::ParallelFor(task, size); } template -void ZerosLike(const T *in, T *out, size_t start, size_t end) { - for (size_t i = start; i < end; i++) { - out[i] = static_cast(0); - } +void ZerosLike(const T *in, T *out, size_t size) { + auto task = [&](size_t start, size_t end) { + for (size_t i = start; i < end; i++) { + out[i] = static_cast(0); + } + }; + CPUKernelUtils::ParallelFor(task, size); } template -void Floor(const T *in, T *out, size_t start, size_t end) { - for (size_t i = start; i < end; i++) { - out[i] = static_cast(floor(in[i])); - } +void Floor(const T *in, T *out, size_t size) { + auto task = [&](size_t start, size_t end) { + for (size_t i = start; i < end; i++) { + out[i] = static_cast(floor(in[i])); + } + }; + CPUKernelUtils::ParallelFor(task, size); } template -void Reciprocal(const T *in, T *out, size_t start, size_t end) { - for (size_t i = start; i < end; i++) { - out[i] = static_cast(1.0 / in[i]); - } +void Reciprocal(const T *in, T *out, size_t size) { + auto task = [&](size_t start, size_t end) { + for (size_t i = start; i < end; i++) { + out[i] = static_cast(1.0 / in[i]); + } + }; + CPUKernelUtils::ParallelFor(task, size); } template -void Gelu(const T *in, T *out, size_t start, size_t end) { - for (size_t i = start; i < end; i++) { - T x = in[i]; - auto double_x = static_cast(x); - T tanh_res = (T)std::tanh(0.7978845608 * (double_x + 0.044715 * double_x * double_x * double_x)); - out[i] = x * ((T)1.0 + tanh_res) / (T)2.0; - } +void Gelu(const T *in, T *out, size_t size) { + auto task = [&](size_t start, size_t end) { + for (size_t i = start; i < end; i++) { + T x = in[i]; + auto double_x = static_cast(x); + T tanh_res = (T)std::tanh(0.7978845608 * (double_x + 0.044715 * double_x * double_x * double_x)); + out[i] = x * ((T)1.0 + tanh_res) / (T)2.0; + } + }; + CPUKernelUtils::ParallelFor(task, size); } template -void Asin(const T *in, T *out, size_t start, size_t end) { - for (size_t i = start; i < end; i++) { - out[i] = asin(in[i]); - } +void Asin(const T *in, T *out, size_t size) { + auto task = [&](size_t start, size_t end) { + for (size_t i = start; i < end; i++) { + out[i] = asin(in[i]); + } + }; + CPUKernelUtils::ParallelFor(task, size); } template -void ACos(const T *in, T *out, size_t start, size_t end) { - for (size_t i = start; i < end; i++) { - out[i] = acos(in[i]); - } +void ACos(const T *in, T *out, size_t size) { + auto task = [&](size_t start, size_t end) { + for (size_t i = start; i < end; i++) { + out[i] = acos(in[i]); + } + }; + CPUKernelUtils::ParallelFor(task, size); } template -void Atan(const T *in, T *out, size_t start, size_t end) { - for (size_t i = start; i < end; i++) { - out[i] = atan(in[i]); - } +void Atan(const T *in, T *out, size_t size) { + auto task = [&](size_t start, size_t end) { + for (size_t i = start; i < end; i++) { + out[i] = atan(in[i]); + } + }; + CPUKernelUtils::ParallelFor(task, size); } template -void Sin(const T *in, T *out, size_t start, size_t end) { - for (size_t i = start; i < end; i++) { - out[i] = sin(in[i]); - } +void Sin(const T *in, T *out, size_t size) { + auto task = [&](size_t start, size_t end) { + for (size_t i = start; i < end; i++) { + out[i] = sin(in[i]); + } + }; + CPUKernelUtils::ParallelFor(task, size); } template -void Cos(const T *in, T *out, size_t start, size_t end) { - for (size_t i = start; i < end; i++) { - out[i] = cos(in[i]); - } +void Cos(const T *in, T *out, size_t size) { + auto task = [&](size_t start, size_t end) { + for (size_t i = start; i < end; i++) { + out[i] = cos(in[i]); + } + }; + CPUKernelUtils::ParallelFor(task, size); } template -void Tan(const T *in, T *out, size_t start, size_t end) { - for (size_t i = start; i < end; i++) { - out[i] = tan(in[i]); - } +void Tan(const T *in, T *out, size_t size) { + auto task = [&](size_t start, size_t end) { + for (size_t i = start; i < end; i++) { + out[i] = tan(in[i]); + } + }; + CPUKernelUtils::ParallelFor(task, size); } template -void Sinh(const T *in, T *out, size_t start, size_t end) { - for (size_t i = start; i < end; i++) { - out[i] = sinh(in[i]); - } +void Sinh(const T *in, T *out, size_t size) { + auto task = [&](size_t start, size_t end) { + for (size_t i = start; i < end; i++) { + out[i] = sinh(in[i]); + } + }; + CPUKernelUtils::ParallelFor(task, size); } template -void Cosh(const T *in, T *out, size_t start, size_t end) { - for (size_t i = start; i < end; i++) { - out[i] = cosh(in[i]); - } +void Cosh(const T *in, T *out, size_t size) { + auto task = [&](size_t start, size_t end) { + for (size_t i = start; i < end; i++) { + out[i] = cosh(in[i]); + } + }; + CPUKernelUtils::ParallelFor(task, size); } template -void Asinh(const T *in, T *out, size_t start, size_t end) { - for (size_t i = start; i < end; i++) { - out[i] = asinh(in[i]); - } +void Asinh(const T *in, T *out, size_t size) { + auto task = [&](size_t start, size_t end) { + for (size_t i = start; i < end; i++) { + out[i] = asinh(in[i]); + } + }; + CPUKernelUtils::ParallelFor(task, size); } template -void Acosh(const T *in, T *out, size_t start, size_t end) { - for (size_t i = start; i < end; i++) { - out[i] = acosh(in[i]); - } +void Acosh(const T *in, T *out, size_t size) { + auto task = [&](size_t start, size_t end) { + for (size_t i = start; i < end; i++) { + out[i] = acosh(in[i]); + } + }; + CPUKernelUtils::ParallelFor(task, size); } template -void Atanh(const T *in, T *out, size_t start, size_t end) { - for (size_t i = start; i < end; i++) { - out[i] = atanh(in[i]); - } +void Atanh(const T *in, T *out, size_t size) { + auto task = [&](size_t start, size_t end) { + for (size_t i = start; i < end; i++) { + out[i] = atanh(in[i]); + } + }; + CPUKernelUtils::ParallelFor(task, size); } } // namespace @@ -223,79 +283,31 @@ void ArithmeticSelfCPUKernel::LaunchKernelLogic(const std::vector &i T *input = reinterpret_cast(inputs[0]->addr); T *output = reinterpret_cast(outputs[0]->addr); size_t lens = outputs[0]->size > 0 ? static_cast(outputs[0]->size / sizeof(T)) : 1; - - auto max_thread_num = std::thread::hardware_concurrency(); - size_t thread_num = lens < 128 * max_thread_num ? std::ceil(lens / 128.0) : max_thread_num; - MS_LOG(INFO) << "Lens=" << lens << "; use thread_num=" << thread_num << "; max_thread_num: " << max_thread_num; - std::vector threads; - if (thread_num < 1) { - MS_LOG(ERROR) << "Invalid value: thread_num " << thread_num; - return; - } - threads.reserve(thread_num); - size_t start = 0; - size_t once_compute_size = (lens + thread_num - 1) / thread_num; - if (once_compute_size < 1) { - MS_LOG(ERROR) << "Invalid value: once_compute_size " << once_compute_size; - return; - } - while (start < lens) { - size_t end = (start + once_compute_size) > lens ? lens : (start + once_compute_size); - if (operate_type_ == LOGICALNOT) { - threads.emplace_back(std::thread(LogicalNot, input, output, start, end)); - } - start += once_compute_size; - } - for (size_t i = 0; i < threads.size(); ++i) { - threads[i].join(); - } + LogicalNot(input, output, lens); + return; } template void ArithmeticSelfCPUKernel::LaunchKernel(const std::vector &inputs, const std::vector &outputs) { - if (target_dtype_ == kNumberTypeBool) { - LaunchKernelLogic(inputs, outputs); - return; - } T *input = reinterpret_cast(inputs[0]->addr); T *output = reinterpret_cast(outputs[0]->addr); size_t lens = outputs[0]->size > 0 ? static_cast(outputs[0]->size / sizeof(T)) : 1; - - auto max_thread_num = std::thread::hardware_concurrency(); - size_t thread_num = lens < 128 * max_thread_num ? std::ceil(lens / 128.0) : max_thread_num; - MS_LOG(INFO) << "Lens=" << lens << "; use thread_num=" << thread_num << "; max_thread_num: " << max_thread_num; - std::vector threads; - if (thread_num < 1) { - MS_LOG(ERROR) << "Invalid value: thread_num " << thread_num; - return; - } - threads.reserve(thread_num); - size_t start = 0; - size_t once_compute_size = (lens + thread_num - 1) / thread_num; - if (once_compute_size < 1) { - MS_LOG(ERROR) << "Invalid value: once_compute_size " << once_compute_size; - return; - } - static const std::map> - kArithmeticOpFuncMap = {{SQUARE, Square}, {SIGN, Sign}, - {NEG, Neg}, {LOGICALNOT, LogicalNot}, - {ONESLIKE, OnesLike}, {ZEROSLIKE, ZerosLike}, - {FLOOR, Floor}, {RECIPROCAL, Reciprocal}, - {GELU, Gelu}, {SIN, Sin}, - {COS, Cos}, {TAN, Tan}, - {ASIN, Asin}, {ACOS, ACos}, - {ATAN, Atan}, {SINH, Sinh}, - {COSH, Cosh}, {ASINH, Asinh}, - {ACOSH, Acosh}, {ATANH, Atanh}}; - - while (start < lens) { - size_t end = (start + once_compute_size) > lens ? lens : (start + once_compute_size); - threads.emplace_back(std::thread(kArithmeticOpFuncMap.at(operate_type_), input, output, start, end)); - start += once_compute_size; - } - for (size_t i = 0; i < threads.size(); ++i) { - threads[i].join(); + static const std::map> kArithmeticOpFuncMap = { + {SQUARE, Square}, {SIGN, Sign}, + {NEG, Neg}, {LOGICALNOT, LogicalNot}, + {ONESLIKE, OnesLike}, {ZEROSLIKE, ZerosLike}, + {FLOOR, Floor}, {RECIPROCAL, Reciprocal}, + {GELU, Gelu}, {SIN, Sin}, + {COS, Cos}, {TAN, Tan}, + {ASIN, Asin}, {ACOS, ACos}, + {ATAN, Atan}, {SINH, Sinh}, + {COSH, Cosh}, {ASINH, Asinh}, + {ACOSH, Acosh}, {ATANH, Atanh}}; + if (kArithmeticOpFuncMap.find(operate_type_) != kArithmeticOpFuncMap.end()) { + kArithmeticOpFuncMap.at(operate_type_)(input, output, lens); + } else { + MS_LOG(EXCEPTION) << "Not support " << operate_type_; } } } // namespace kernel diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.h index 0a68b3722b..84b6e1d6b0 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.h @@ -34,7 +34,6 @@ class ArithmeticSelfCPUKernel : public CPUKernel { template void LaunchKernel(const std::vector &inputs, const std::vector &outputs); - template void LaunchKernelLogic(const std::vector &inputs, const std::vector &outputs); diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/cast_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/cast_cpu_kernel.cc index b0226fd40b..22574b98c4 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/cast_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/cast_cpu_kernel.cc @@ -16,220 +16,38 @@ #include #include #include -#include #include "backend/kernel_compiler/cpu/cast_cpu_kernel.h" #include "runtime/device/cpu/cpu_device_address.h" namespace mindspore { namespace kernel { template -void Cast(const S *in, T *out, size_t start, size_t end) { - for (size_t i = start; i < end; i++) { - out[i] = static_cast(in[i]); - } +void Cast(const S *in, T *out, size_t size) { + auto task = [&](size_t start, size_t end) { + for (size_t i = start; i < end; i++) { + out[i] = static_cast(in[i]); + } + }; + CPUKernelUtils::ParallelFor(task, size); } template -void LaunchCast(const std::vector &inputs, const std::vector &outputs) { - S *input = reinterpret_cast(inputs[0]->addr); - T *output = reinterpret_cast(outputs[0]->addr); - MS_LOG(DEBUG) << "Type source: " << typeid(S).name() << "; target: " << typeid(T).name(); - - size_t lens = outputs[0]->size > 0 ? static_cast(outputs[0]->size / sizeof(T)) : 1; - auto max_thread_num = std::thread::hardware_concurrency(); - size_t thread_num = lens < 128 * max_thread_num ? std::ceil(lens / 128.0) : max_thread_num; - MS_LOG(INFO) << "Lens=" << lens << "; use thread_num=" << thread_num << "; max_thread_num: " << max_thread_num; - std::vector threads; - if (thread_num < 1) { - MS_LOG(ERROR) << "Invalid value: thread_num " << thread_num; - return; - } - threads.reserve(thread_num); - size_t start = 0; - size_t once_compute_size = (lens + thread_num - 1) / thread_num; - if (once_compute_size < 1) { - MS_LOG(ERROR) << "Invalid value: once_compute_size " << once_compute_size; - return; - } - while (start < lens) { - size_t end = (start + once_compute_size) > lens ? lens : (start + once_compute_size); - threads.emplace_back(std::thread(Cast, input, output, start, end)); - start += once_compute_size; - } - for (size_t i = 0; i < threads.size(); ++i) { - threads[i].join(); - } -} - -void CastCPUKernel::InitKernel(const CNodePtr &kernel_node) { +void CastCPUKernel::InitKernel(const CNodePtr &kernel_node) { MS_EXCEPTION_IF_NULL(kernel_node); source_dtype = AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, 0); target_dtype = AnfAlgo::GetOutputInferDataType(kernel_node, 0); } -bool CastCPUKernel::Launch(const std::vector &inputs, - const std::vector & /*workspace*/, - const std::vector &outputs) { - using TypePair = - std::function &, const std::vector &)>; - std::map> mode_map; - mode_map[kNumberTypeBool][kNumberTypeFloat16] = LaunchCast; - mode_map[kNumberTypeBool][kNumberTypeFloat32] = LaunchCast; - mode_map[kNumberTypeBool][kNumberTypeFloat64] = LaunchCast; - mode_map[kNumberTypeBool][kNumberTypeInt8] = LaunchCast; - mode_map[kNumberTypeBool][kNumberTypeInt16] = LaunchCast; - mode_map[kNumberTypeBool][kNumberTypeInt32] = LaunchCast; - mode_map[kNumberTypeBool][kNumberTypeInt64] = LaunchCast; - mode_map[kNumberTypeBool][kNumberTypeUInt8] = LaunchCast; - mode_map[kNumberTypeBool][kNumberTypeUInt16] = LaunchCast; - mode_map[kNumberTypeBool][kNumberTypeUInt32] = LaunchCast; - mode_map[kNumberTypeBool][kNumberTypeUInt64] = LaunchCast; - mode_map[kNumberTypeBool][kNumberTypeBool] = LaunchCast; - - mode_map[kNumberTypeFloat16][kNumberTypeFloat16] = LaunchCast; - mode_map[kNumberTypeFloat16][kNumberTypeFloat32] = LaunchCast; - mode_map[kNumberTypeFloat16][kNumberTypeFloat64] = LaunchCast; - mode_map[kNumberTypeFloat16][kNumberTypeInt8] = LaunchCast; - mode_map[kNumberTypeFloat16][kNumberTypeInt16] = LaunchCast; - mode_map[kNumberTypeFloat16][kNumberTypeInt32] = LaunchCast; - mode_map[kNumberTypeFloat16][kNumberTypeInt64] = LaunchCast; - mode_map[kNumberTypeFloat16][kNumberTypeUInt8] = LaunchCast; - mode_map[kNumberTypeFloat16][kNumberTypeUInt16] = LaunchCast; - mode_map[kNumberTypeFloat16][kNumberTypeUInt32] = LaunchCast; - mode_map[kNumberTypeFloat16][kNumberTypeUInt64] = LaunchCast; - mode_map[kNumberTypeFloat16][kNumberTypeBool] = LaunchCast; - - mode_map[kNumberTypeFloat32][kNumberTypeFloat16] = LaunchCast; - mode_map[kNumberTypeFloat32][kNumberTypeFloat32] = LaunchCast; - mode_map[kNumberTypeFloat32][kNumberTypeFloat64] = LaunchCast; - mode_map[kNumberTypeFloat32][kNumberTypeInt8] = LaunchCast; - mode_map[kNumberTypeFloat32][kNumberTypeInt16] = LaunchCast; - mode_map[kNumberTypeFloat32][kNumberTypeInt32] = LaunchCast; - mode_map[kNumberTypeFloat32][kNumberTypeInt64] = LaunchCast; - mode_map[kNumberTypeFloat32][kNumberTypeUInt8] = LaunchCast; - mode_map[kNumberTypeFloat32][kNumberTypeUInt16] = LaunchCast; - mode_map[kNumberTypeFloat32][kNumberTypeUInt32] = LaunchCast; - mode_map[kNumberTypeFloat32][kNumberTypeUInt64] = LaunchCast; - mode_map[kNumberTypeFloat32][kNumberTypeBool] = LaunchCast; - - mode_map[kNumberTypeFloat64][kNumberTypeFloat16] = LaunchCast; - mode_map[kNumberTypeFloat64][kNumberTypeFloat32] = LaunchCast; - mode_map[kNumberTypeFloat64][kNumberTypeFloat64] = LaunchCast; - mode_map[kNumberTypeFloat64][kNumberTypeInt8] = LaunchCast; - mode_map[kNumberTypeFloat64][kNumberTypeInt16] = LaunchCast; - mode_map[kNumberTypeFloat64][kNumberTypeInt32] = LaunchCast; - mode_map[kNumberTypeFloat64][kNumberTypeInt64] = LaunchCast; - mode_map[kNumberTypeFloat64][kNumberTypeUInt8] = LaunchCast; - mode_map[kNumberTypeFloat64][kNumberTypeUInt16] = LaunchCast; - mode_map[kNumberTypeFloat64][kNumberTypeUInt32] = LaunchCast; - mode_map[kNumberTypeFloat64][kNumberTypeUInt64] = LaunchCast; - mode_map[kNumberTypeFloat64][kNumberTypeBool] = LaunchCast; - - mode_map[kNumberTypeInt8][kNumberTypeFloat16] = LaunchCast; - mode_map[kNumberTypeInt8][kNumberTypeFloat32] = LaunchCast; - mode_map[kNumberTypeInt8][kNumberTypeFloat64] = LaunchCast; - mode_map[kNumberTypeInt8][kNumberTypeInt8] = LaunchCast; - mode_map[kNumberTypeInt8][kNumberTypeInt16] = LaunchCast; - mode_map[kNumberTypeInt8][kNumberTypeInt32] = LaunchCast; - mode_map[kNumberTypeInt8][kNumberTypeInt64] = LaunchCast; - mode_map[kNumberTypeInt8][kNumberTypeUInt8] = LaunchCast; - mode_map[kNumberTypeInt8][kNumberTypeUInt16] = LaunchCast; - mode_map[kNumberTypeInt8][kNumberTypeUInt32] = LaunchCast; - mode_map[kNumberTypeInt8][kNumberTypeUInt64] = LaunchCast; - mode_map[kNumberTypeInt8][kNumberTypeBool] = LaunchCast; - - mode_map[kNumberTypeInt16][kNumberTypeFloat16] = LaunchCast; - mode_map[kNumberTypeInt16][kNumberTypeFloat32] = LaunchCast; - mode_map[kNumberTypeInt16][kNumberTypeFloat64] = LaunchCast; - mode_map[kNumberTypeInt16][kNumberTypeInt8] = LaunchCast; - mode_map[kNumberTypeInt16][kNumberTypeInt16] = LaunchCast; - mode_map[kNumberTypeInt16][kNumberTypeInt32] = LaunchCast; - mode_map[kNumberTypeInt16][kNumberTypeInt64] = LaunchCast; - mode_map[kNumberTypeInt16][kNumberTypeUInt8] = LaunchCast; - mode_map[kNumberTypeInt16][kNumberTypeUInt16] = LaunchCast; - mode_map[kNumberTypeInt16][kNumberTypeUInt32] = LaunchCast; - mode_map[kNumberTypeInt16][kNumberTypeUInt64] = LaunchCast; - mode_map[kNumberTypeInt16][kNumberTypeBool] = LaunchCast; - - mode_map[kNumberTypeInt32][kNumberTypeFloat16] = LaunchCast; - mode_map[kNumberTypeInt32][kNumberTypeFloat32] = LaunchCast; - mode_map[kNumberTypeInt32][kNumberTypeFloat64] = LaunchCast; - mode_map[kNumberTypeInt32][kNumberTypeInt8] = LaunchCast; - mode_map[kNumberTypeInt32][kNumberTypeInt16] = LaunchCast; - mode_map[kNumberTypeInt32][kNumberTypeInt32] = LaunchCast; - mode_map[kNumberTypeInt32][kNumberTypeInt64] = LaunchCast; - mode_map[kNumberTypeInt32][kNumberTypeUInt8] = LaunchCast; - mode_map[kNumberTypeInt32][kNumberTypeUInt16] = LaunchCast; - mode_map[kNumberTypeInt32][kNumberTypeUInt32] = LaunchCast; - mode_map[kNumberTypeInt32][kNumberTypeUInt64] = LaunchCast; - mode_map[kNumberTypeInt32][kNumberTypeBool] = LaunchCast; - - mode_map[kNumberTypeInt64][kNumberTypeFloat16] = LaunchCast; - mode_map[kNumberTypeInt64][kNumberTypeFloat32] = LaunchCast; - mode_map[kNumberTypeInt64][kNumberTypeFloat64] = LaunchCast; - mode_map[kNumberTypeInt64][kNumberTypeInt8] = LaunchCast; - mode_map[kNumberTypeInt64][kNumberTypeInt16] = LaunchCast; - mode_map[kNumberTypeInt64][kNumberTypeInt32] = LaunchCast; - mode_map[kNumberTypeInt64][kNumberTypeInt64] = LaunchCast; - mode_map[kNumberTypeInt64][kNumberTypeUInt8] = LaunchCast; - mode_map[kNumberTypeInt64][kNumberTypeUInt16] = LaunchCast; - mode_map[kNumberTypeInt64][kNumberTypeUInt32] = LaunchCast; - mode_map[kNumberTypeInt64][kNumberTypeUInt64] = LaunchCast; - mode_map[kNumberTypeInt64][kNumberTypeBool] = LaunchCast; - - mode_map[kNumberTypeUInt8][kNumberTypeFloat16] = LaunchCast; - mode_map[kNumberTypeUInt8][kNumberTypeFloat32] = LaunchCast; - mode_map[kNumberTypeUInt8][kNumberTypeFloat64] = LaunchCast; - mode_map[kNumberTypeUInt8][kNumberTypeInt8] = LaunchCast; - mode_map[kNumberTypeUInt8][kNumberTypeInt16] = LaunchCast; - mode_map[kNumberTypeUInt8][kNumberTypeInt32] = LaunchCast; - mode_map[kNumberTypeUInt8][kNumberTypeInt64] = LaunchCast; - mode_map[kNumberTypeUInt8][kNumberTypeUInt8] = LaunchCast; - mode_map[kNumberTypeUInt8][kNumberTypeUInt16] = LaunchCast; - mode_map[kNumberTypeUInt8][kNumberTypeUInt32] = LaunchCast; - mode_map[kNumberTypeUInt8][kNumberTypeUInt64] = LaunchCast; - mode_map[kNumberTypeUInt8][kNumberTypeBool] = LaunchCast; - - mode_map[kNumberTypeUInt16][kNumberTypeFloat16] = LaunchCast; - mode_map[kNumberTypeUInt16][kNumberTypeFloat32] = LaunchCast; - mode_map[kNumberTypeUInt16][kNumberTypeFloat64] = LaunchCast; - mode_map[kNumberTypeUInt16][kNumberTypeInt8] = LaunchCast; - mode_map[kNumberTypeUInt16][kNumberTypeInt16] = LaunchCast; - mode_map[kNumberTypeUInt16][kNumberTypeInt32] = LaunchCast; - mode_map[kNumberTypeUInt16][kNumberTypeInt64] = LaunchCast; - mode_map[kNumberTypeUInt16][kNumberTypeUInt8] = LaunchCast; - mode_map[kNumberTypeUInt16][kNumberTypeUInt16] = LaunchCast; - mode_map[kNumberTypeUInt16][kNumberTypeUInt32] = LaunchCast; - mode_map[kNumberTypeUInt16][kNumberTypeUInt64] = LaunchCast; - mode_map[kNumberTypeUInt16][kNumberTypeBool] = LaunchCast; - - mode_map[kNumberTypeUInt32][kNumberTypeFloat16] = LaunchCast; - mode_map[kNumberTypeUInt32][kNumberTypeFloat32] = LaunchCast; - mode_map[kNumberTypeUInt32][kNumberTypeFloat64] = LaunchCast; - mode_map[kNumberTypeUInt32][kNumberTypeInt8] = LaunchCast; - mode_map[kNumberTypeUInt32][kNumberTypeInt16] = LaunchCast; - mode_map[kNumberTypeUInt32][kNumberTypeInt32] = LaunchCast; - mode_map[kNumberTypeUInt32][kNumberTypeInt64] = LaunchCast; - mode_map[kNumberTypeUInt32][kNumberTypeUInt8] = LaunchCast; - mode_map[kNumberTypeUInt32][kNumberTypeUInt16] = LaunchCast; - mode_map[kNumberTypeUInt32][kNumberTypeUInt32] = LaunchCast; - mode_map[kNumberTypeUInt32][kNumberTypeUInt64] = LaunchCast; - mode_map[kNumberTypeUInt32][kNumberTypeBool] = LaunchCast; - - mode_map[kNumberTypeUInt64][kNumberTypeFloat16] = LaunchCast; - mode_map[kNumberTypeUInt64][kNumberTypeFloat32] = LaunchCast; - mode_map[kNumberTypeUInt64][kNumberTypeFloat64] = LaunchCast; - mode_map[kNumberTypeUInt64][kNumberTypeInt8] = LaunchCast; - mode_map[kNumberTypeUInt64][kNumberTypeInt16] = LaunchCast; - mode_map[kNumberTypeUInt64][kNumberTypeInt32] = LaunchCast; - mode_map[kNumberTypeUInt64][kNumberTypeInt64] = LaunchCast; - mode_map[kNumberTypeUInt64][kNumberTypeUInt8] = LaunchCast; - mode_map[kNumberTypeUInt64][kNumberTypeUInt16] = LaunchCast; - mode_map[kNumberTypeUInt64][kNumberTypeUInt32] = LaunchCast; - mode_map[kNumberTypeUInt64][kNumberTypeUInt64] = LaunchCast; - mode_map[kNumberTypeUInt64][kNumberTypeBool] = LaunchCast; +template +bool CastCPUKernel::Launch(const std::vector &inputs, + const std::vector & /*workspace*/, + const std::vector &outputs) { + S *input = reinterpret_cast(inputs[0]->addr); + T *output = reinterpret_cast(outputs[0]->addr); + MS_LOG(DEBUG) << "Type source: " << typeid(S).name() << "; target: " << typeid(T).name(); - mode_map[source_dtype][target_dtype](inputs, outputs); + size_t lens = outputs[0]->size > 0 ? static_cast(outputs[0]->size / sizeof(T)) : 1; + Cast(input, output, lens); return true; } } // namespace kernel diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/cast_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/cast_cpu_kernel.h index 7c68e8f612..d4ea77f095 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/cast_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/cast_cpu_kernel.h @@ -23,6 +23,7 @@ namespace mindspore { namespace kernel { +template class CastCPUKernel : public CPUKernel { public: CastCPUKernel() = default; @@ -38,161 +39,305 @@ class CastCPUKernel : public CPUKernel { TypeId target_dtype{kTypeUnknown}; }; -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeFloat16), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeFloat32), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeFloat64), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt8), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt16), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt32), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt64), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), CastCPUKernel); - -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat32), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat64), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt8), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt16), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt64), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool), CastCPUKernel); - -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat16), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat64), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt8), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt16), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt64), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool), CastCPUKernel); - -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat16), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat32), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt8), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt16), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt32), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt64), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool), CastCPUKernel); - -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeFloat16), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeFloat32), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeFloat64), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt16), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt32), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt64), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeBool), CastCPUKernel); - -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeFloat16), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeFloat32), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeFloat64), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt8), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt32), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt64), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeBool), CastCPUKernel); - -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool), CastCPUKernel); - -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool), CastCPUKernel); - -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeFloat16), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeFloat32), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeFloat64), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt8), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt16), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt32), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt64), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeBool), CastCPUKernel); - -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeFloat16), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeFloat32), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeFloat64), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt8), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt16), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt32), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt64), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeBool), CastCPUKernel); - -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeFloat16), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeFloat32), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeFloat64), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt8), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt16), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt32), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt64), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeBool), CastCPUKernel); - -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeFloat16), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeFloat32), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeFloat64), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt8), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt16), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt32), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt64), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeBool), CastCPUKernel); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeFloat16), CastCPUKernel, + bool, float16); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeFloat32), CastCPUKernel, + bool, float); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeFloat64), CastCPUKernel, + bool, double); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt8), CastCPUKernel, + bool, int8_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt16), CastCPUKernel, + bool, int16_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt32), CastCPUKernel, + bool, int32_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt64), CastCPUKernel, + bool, int64_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel, + bool, uint8_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel, + bool, uint16_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel, + bool, uint32_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel, + bool, uint64_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), CastCPUKernel, + bool, bool); + +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + CastCPUKernel, float16, float16); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat32), + CastCPUKernel, float16, float); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat64), + CastCPUKernel, float16, double); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt8), CastCPUKernel, + float16, int8_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt16), + CastCPUKernel, float16, int16_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32), + CastCPUKernel, float16, int32_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt64), + CastCPUKernel, float16, int64_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeUInt8), + CastCPUKernel, float16, uint8_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeUInt16), + CastCPUKernel, float16, uint16_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeUInt32), + CastCPUKernel, float16, uint32_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeUInt64), + CastCPUKernel, float16, uint64_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool), CastCPUKernel, + float16, bool); + +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat16), + CastCPUKernel, float, float16); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + CastCPUKernel, float, float); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat64), + CastCPUKernel, float, double); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt8), CastCPUKernel, + float, int8_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt16), + CastCPUKernel, float, int16_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32), + CastCPUKernel, float, int32_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt64), + CastCPUKernel, float, int64_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeUInt8), + CastCPUKernel, float, uint8_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeUInt16), + CastCPUKernel, float, uint16_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeUInt32), + CastCPUKernel, float, uint32_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeUInt64), + CastCPUKernel, float, uint64_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool), CastCPUKernel, + float, bool); + +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat16), + CastCPUKernel, double, float16); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat32), + CastCPUKernel, double, float); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), + CastCPUKernel, double, double); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt8), CastCPUKernel, + double, int8_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt16), + CastCPUKernel, double, int16_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt32), + CastCPUKernel, double, int32_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt64), + CastCPUKernel, double, int64_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeUInt8), + CastCPUKernel, double, uint8_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeUInt16), + CastCPUKernel, double, uint16_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeUInt32), + CastCPUKernel, double, uint32_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeUInt64), + CastCPUKernel, double, uint64_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool), CastCPUKernel, + double, bool); + +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeFloat16), CastCPUKernel, + int8_t, float16); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeFloat32), CastCPUKernel, + int8_t, float); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeFloat64), CastCPUKernel, + int8_t, double); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), CastCPUKernel, + int8_t, int8_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt16), CastCPUKernel, + int8_t, int16_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt32), CastCPUKernel, + int8_t, int32_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt64), CastCPUKernel, + int8_t, int64_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel, + int8_t, uint8_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel, + int8_t, uint16_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel, + int8_t, uint32_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel, + int8_t, uint64_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeBool), CastCPUKernel, + int8_t, bool); + +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeFloat16), + CastCPUKernel, int16_t, float16); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeFloat32), + CastCPUKernel, int16_t, float); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeFloat64), + CastCPUKernel, int16_t, double); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt8), CastCPUKernel, + int16_t, int8_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16), CastCPUKernel, + int16_t, int16_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt32), CastCPUKernel, + int16_t, int32_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt64), CastCPUKernel, + int16_t, int64_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel, + int16_t, uint8_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel, + int16_t, uint16_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel, + int16_t, uint32_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel, + int16_t, uint64_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeBool), CastCPUKernel, + int16_t, bool); + +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16), + CastCPUKernel, int32_t, float16); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), + CastCPUKernel, int32_t, float); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64), + CastCPUKernel, int32_t, double); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8), CastCPUKernel, + int32_t, int8_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16), CastCPUKernel, + int32_t, int16_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), CastCPUKernel, + int32_t, int32_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64), CastCPUKernel, + int32_t, int64_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel, + int32_t, uint8_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel, + int32_t, uint16_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel, + int32_t, uint32_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel, + int32_t, uint64_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool), CastCPUKernel, + int32_t, bool); + +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16), + CastCPUKernel, int64_t, float16); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32), + CastCPUKernel, int64_t, float); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64), + CastCPUKernel, int64_t, double); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8), CastCPUKernel, + int64_t, int8_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16), CastCPUKernel, + int64_t, int16_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32), CastCPUKernel, + int64_t, int32_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), CastCPUKernel, + int64_t, int64_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel, + int64_t, uint8_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel, + int64_t, uint16_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel, + int64_t, uint32_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel, + int64_t, uint64_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool), CastCPUKernel, + int64_t, bool); + +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeFloat16), + CastCPUKernel, uint8_t, float16); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeFloat32), + CastCPUKernel, uint8_t, float); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeFloat64), + CastCPUKernel, uint8_t, double); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt8), CastCPUKernel, + uint8_t, int8_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt16), CastCPUKernel, + uint8_t, int16_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt32), CastCPUKernel, + uint8_t, int32_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt64), CastCPUKernel, + uint8_t, int64_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel, + uint8_t, uint8_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel, + uint8_t, uint16_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel, + uint8_t, uint32_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel, + uint8_t, uint64_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeBool), CastCPUKernel, + uint8_t, bool); + +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeFloat16), + CastCPUKernel, uint16_t, float16); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeFloat32), + CastCPUKernel, uint16_t, float); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeFloat64), + CastCPUKernel, uint16_t, double); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt8), CastCPUKernel, + uint16_t, int8_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt16), CastCPUKernel, + uint16_t, int16_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt32), CastCPUKernel, + uint16_t, int32_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt64), CastCPUKernel, + uint16_t, int64_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel, + uint16_t, uint8_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16), + CastCPUKernel, uint16_t, uint16_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt32), + CastCPUKernel, uint16_t, uint32_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt64), + CastCPUKernel, uint16_t, uint64_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeBool), CastCPUKernel, + uint16_t, bool); + +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeFloat16), + CastCPUKernel, uint32_t, float16); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeFloat32), + CastCPUKernel, uint32_t, float); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeFloat64), + CastCPUKernel, uint32_t, double); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt8), CastCPUKernel, + uint32_t, int8_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt16), CastCPUKernel, + uint32_t, int16_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt32), CastCPUKernel, + uint32_t, int32_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt64), CastCPUKernel, + uint32_t, int64_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel, + uint32_t, uint8_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt16), + CastCPUKernel, uint32_t, uint16_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32), + CastCPUKernel, uint32_t, uint32_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt64), + CastCPUKernel, uint32_t, uint64_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeBool), CastCPUKernel, + uint32_t, bool); + +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeFloat16), + CastCPUKernel, uint64_t, float16); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeFloat32), + CastCPUKernel, uint64_t, float); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeFloat64), + CastCPUKernel, uint64_t, double); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt8), CastCPUKernel, + uint64_t, int8_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt16), CastCPUKernel, + uint64_t, int16_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt32), CastCPUKernel, + uint64_t, int32_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt64), CastCPUKernel, + uint64_t, int64_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel, + uint64_t, uint8_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt16), + CastCPUKernel, uint64_t, uint16_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt32), + CastCPUKernel, uint64_t, uint32_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64), + CastCPUKernel, uint64_t, uint64_t); +MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeBool), CastCPUKernel, + uint64_t, bool); } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/layer_norm_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/layer_norm_cpu_kernel.cc index 353ee5d4bd..548b23a3a0 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/layer_norm_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/layer_norm_cpu_kernel.cc @@ -14,8 +14,11 @@ * limitations under the License. */ +#include #include "backend/kernel_compiler/cpu/layer_norm_cpu_kernel.h" +#include "backend/kernel_compiler/common_utils.h" #include "runtime/device/cpu/cpu_device_address.h" +#include "common/thread_pool.h" namespace mindspore { namespace kernel { @@ -72,23 +75,43 @@ void LayerNormCPUKernel::LaunchKernel(const std::vector &inputs, con auto y = reinterpret_cast(outputs[0]->addr); auto mean = reinterpret_cast(outputs[1]->addr); auto var = reinterpret_cast(outputs[2]->addr); - for (size_t i = 0; i < block_num_; ++i) { - T sum = (T)0.0; - T square_sum = (T)0.0; - for (size_t j = i * block_size_; j < (i + 1) * block_size_; ++j) { - sum += x[j]; - square_sum += x[j] * x[j]; - } - T block_mean = sum / block_size_; - T block_var = square_sum / block_size_ - block_mean * block_mean; - for (size_t j = i * block_size_; j < (i + 1) * block_size_; ++j) { - auto param_shift = j % param_num_; - y[j] = (x[j] - block_mean) / (T)std::sqrt(static_cast(block_var) + eps_) * gamma[param_shift] + - beta[param_shift]; + size_t thread_num = common::ThreadPool::GetInstance().GetSyncRunThreadNum(); + if (block_num_ < thread_num) { + thread_num = block_num_; + } + std::vector tasks; + tasks.reserve(thread_num); + auto task = [&](size_t start, size_t end) { + for (size_t c = 0; c < ceil(static_cast(block_num_) / thread_num); ++c) { + if (c * thread_num + start >= block_num_) { + continue; + } + size_t i = c * thread_num + start; + T sum = (T)0.0; + T square_sum = (T)0.0; + for (size_t j = i * block_size_; j < (i + 1) * block_size_; ++j) { + sum += x[j]; + square_sum += x[j] * x[j]; + } + T block_mean = sum / block_size_; + T block_var = square_sum / block_size_ - block_mean * block_mean; + for (size_t j = i * block_size_; j < (i + 1) * block_size_; ++j) { + auto param_shift = j % param_num_; + y[j] = (x[j] - block_mean) / (T)std::sqrt(static_cast(block_var) + eps_) * gamma[param_shift] + + beta[param_shift]; + } + mean[i] = block_mean; + var[i] = block_var; } - mean[i] = block_mean; - var[i] = block_var; + }; + for (size_t i = 0; i < thread_num; ++i) { + auto block = [&, i]() { + task(i, i + 1); + return common::SUCCESS; + }; + tasks.emplace_back(block); } + common::ThreadPool::GetInstance().SyncRun(tasks); } void LayerNormCPUKernel::CheckParam(const CNodePtr &kernel_node) { diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/layer_norm_grad_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/layer_norm_grad_cpu_kernel.cc index 63cefe0ab9..1b9f4c89c2 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/layer_norm_grad_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/layer_norm_grad_cpu_kernel.cc @@ -15,7 +15,9 @@ */ #include "backend/kernel_compiler/cpu/layer_norm_grad_cpu_kernel.h" +#include "backend/kernel_compiler/common_utils.h" #include "runtime/device/cpu/cpu_device_address.h" +#include "common/thread_pool.h" namespace mindspore { namespace kernel { @@ -73,41 +75,75 @@ void LayerNormGradCPUKernel::LaunchKernel(const std::vector &inputs, auto dx = reinterpret_cast(outputs[0]->addr); auto dg = reinterpret_cast(outputs[1]->addr); auto db = reinterpret_cast(outputs[2]->addr); - - for (size_t i = 0; i < param_num_; ++i) { - T dgamma = (T)0.0; - T dbeta = (T)0.0; - for (size_t j = i; j < param_size_ * param_num_; j += param_num_) { - auto norm_shift = static_cast(j / block_size_); - dgamma += dy[j] * (T)std::pow(static_cast(var[norm_shift]) + eps_, -0.5) * (x[j] - mean[norm_shift]); - dbeta += dy[j]; - } - dg[i] = dgamma; - db[i] = dbeta; - } - for (size_t i = 0; i < block_num_; ++i) { - T sum1 = (T)0.0; - T sum2 = (T)0.0; - T sum3 = (T)0.0; - for (size_t j = i * block_size_; j < (i + 1) * block_size_; ++j) { - auto param_shift = j % param_num_; - auto norm_shift = static_cast(j / block_size_); - auto dxm = x[j] - mean[norm_shift]; - auto dyg = dy[j] * gamma[param_shift]; - sum1 += (T)(-0.5) * dyg * dxm * (T)std::pow(static_cast(var[norm_shift]) + eps_, -1.5); - sum2 += dyg; - sum3 += (T)(-2.0) * dxm; + size_t thread_num = common::ThreadPool::GetInstance().GetSyncRunThreadNum(); + auto thread_num1 = param_num_ < thread_num ? param_num_ : thread_num; + std::vector tasks1; + tasks1.reserve(thread_num1); + auto thread_num2 = block_num_ < thread_num ? block_num_ : thread_num; + std::vector tasks2; + tasks2.reserve(thread_num2); + auto task1 = [&](size_t start, size_t end) { + for (size_t c = 0; c < ceil(static_cast(param_num_) / thread_num1); ++c) { + if (c * thread_num1 + start >= param_num_) { + continue; + } + size_t param_index = c * thread_num1 + start; + T dgamma = (T)0.0; + T dbeta = (T)0.0; + for (size_t j = param_index; j < param_size_ * param_num_; j += param_num_) { + auto norm_shift = static_cast(j / block_size_); + dgamma += dy[j] * (T)std::pow(static_cast(var[norm_shift]) + eps_, -0.5) * (x[j] - mean[norm_shift]); + dbeta += dy[j]; + } + dg[param_index] = dgamma; + db[param_index] = dbeta; } - for (size_t j = i * block_size_; j < (i + 1) * block_size_; ++j) { - auto param_shift = j % param_num_; - auto norm_shift = static_cast(j / block_size_); - auto var_sqrt = (T)std::pow(static_cast(var[norm_shift]) + eps_, -0.5); - auto dx1 = dy[j] * gamma[param_shift] * var_sqrt; - auto dx2 = sum1 * (T)2.0 / block_size_ * (x[j] - mean[norm_shift]); - auto dx3 = ((T)(-1.0) * var_sqrt * sum2 + ((T)1.0 / block_size_) * sum1 * sum3) * ((T)1.0 / block_size_); - dx[j] = dx1 + dx2 + dx3; + }; + auto task2 = [&](size_t start, size_t end) { + for (size_t c = 0; c < ceil(static_cast(block_num_) / thread_num2); ++c) { + if (c * thread_num2 + start >= block_num_) { + continue; + } + size_t block_index = c * thread_num2 + start; + T sum1 = (T)0.0; + T sum2 = (T)0.0; + T sum3 = (T)0.0; + for (size_t j = block_index * block_size_; j < (block_index + 1) * block_size_; ++j) { + auto param_shift = j % param_num_; + auto norm_shift = static_cast(j / block_size_); + auto dxm = x[j] - mean[norm_shift]; + auto dyg = dy[j] * gamma[param_shift]; + sum1 += (T)(-0.5) * dyg * dxm * (T)std::pow(static_cast(var[norm_shift]) + eps_, -1.5); + sum2 += dyg; + sum3 += (T)(-2.0) * dxm; + } + for (size_t j = block_index * block_size_; j < (block_index + 1) * block_size_; ++j) { + auto param_shift = j % param_num_; + auto norm_shift = static_cast(j / block_size_); + auto var_sqrt = (T)std::pow(static_cast(var[norm_shift]) + eps_, -0.5); + auto dx1 = dy[j] * gamma[param_shift] * var_sqrt; + auto dx2 = sum1 * (T)2.0 / block_size_ * (x[j] - mean[norm_shift]); + auto dx3 = ((T)(-1.0) * var_sqrt * sum2 + ((T)1.0 / block_size_) * sum1 * sum3) * ((T)1.0 / block_size_); + dx[j] = dx1 + dx2 + dx3; + } } + }; + for (size_t i = 0; i < thread_num1; ++i) { + auto block = [&, i]() { + task1(i, i + 1); + return common::SUCCESS; + }; + tasks1.emplace_back(block); + } + common::ThreadPool::GetInstance().SyncRun(tasks1); + for (size_t i = 0; i < thread_num2; ++i) { + auto block = [&, i]() { + task2(i, i + 1); + return common::SUCCESS; + }; + tasks2.emplace_back(block); } + common::ThreadPool::GetInstance().SyncRun(tasks2); } void LayerNormGradCPUKernel::CheckParam(const CNodePtr &kernel_node) { diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/unsorted_segment_sum_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/unsorted_segment_sum_cpu_kernel.cc index a6505b0927..d9afeb211c 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/unsorted_segment_sum_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/unsorted_segment_sum_cpu_kernel.cc @@ -16,6 +16,7 @@ #include "backend/kernel_compiler/cpu/unsorted_segment_sum_cpu_kernel.h" #include +#include #include "runtime/device/cpu/cpu_device_address.h" #include "common/thread_pool.h" @@ -78,17 +79,37 @@ bool UnsortedSegmentSumCPUKernel::LaunchKernel(const std::vector &in MS_LOG(ERROR) << "Output buff memset fail. ret:" << ret; return false; } - for (size_t i = 0; i < unit_num_; ++i) { - size_t j = i / input_dim1_; - size_t k = i % input_dim1_; + size_t thread_num = common::ThreadPool::GetInstance().GetSyncRunThreadNum(); + if (unit_num_ < thread_num) { + thread_num = unit_num_; + } + std::vector tasks; + tasks.reserve(thread_num); + auto task = [&](size_t start, size_t end) { + for (size_t c = 0; c < ceil(static_cast(unit_num_) / thread_num); ++c) { + if (c * thread_num + start >= unit_num_) { + continue; + } + size_t i = c * thread_num + start; + size_t j = i / input_dim1_; + size_t k = i % input_dim1_; - T index = indices_addr[j]; - if (index < 0 || index >= SizeToInt(output_dim0_)) { - continue; + T index = indices_addr[j]; + if (index < 0 || index >= SizeToInt(output_dim0_)) { + continue; + } + size_t output_index = index * output_dim1_ + k; + output_addr[output_index] += input_addr[i]; } - size_t output_index = index * output_dim1_ + k; - output_addr[output_index] += input_addr[i]; + }; + for (size_t t = 0; t < thread_num; ++t) { + auto block = [&, t]() { + task(t, t + 1); + return common::SUCCESS; + }; + tasks.emplace_back(block); } + common::ThreadPool::GetInstance().SyncRun(tasks); return true; } } // namespace kernel