add parallel for some CPU ops

pull/13022/head
zhaoting 4 years ago
parent fa4c19f938
commit c62baec9a4

@@ -16,7 +16,6 @@
 #include "backend/kernel_compiler/cpu/adam_cpu_kernel.h"
 #include <cmath>
-#include <thread>
 #include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h"
 #include "runtime/device/cpu/cpu_device_address.h"
 #include "utils/ms_utils.h"
@@ -25,7 +24,8 @@ namespace mindspore {
 namespace kernel {
 template <typename T>
 void AdamCPUKernel::LaunchAdam(T *var, T *m, T *v, float lr, float beta1, float beta2, float epsilon, const T *gradient,
-                               size_t start, size_t end) {
+                               size_t size) {
+  auto task = [&](size_t start, size_t end) {
   for (size_t i = start; i < end; i++) {
     m[i] += (gradient[i] - m[i]) * (1 - beta1);
     v[i] += (gradient[i] * gradient[i] - v[i]) * (1 - beta2);
@@ -35,6 +35,8 @@ void AdamCPUKernel::LaunchAdam(T *var, T *m, T *v, float lr, float beta1, float
       var[i] -= lr * m[i] / (std::sqrt(v[i]) + epsilon);
     }
   }
+  };
+  CPUKernelUtils::ParallelFor(task, size);
 }

 void AdamCPUKernel::InitKernel(const CNodePtr &kernel_node) {
@@ -84,31 +86,7 @@ bool AdamCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
-  // multithreading
   size_t lens = inputs[0]->size > 0 ? static_cast<size_t>(inputs[0]->size / sizeof(float)) : 1;
-  auto max_thread_num = std::thread::hardware_concurrency();
-  size_t thread_num = lens < 128 * max_thread_num ? std::ceil(lens / 128.0) : max_thread_num;
-  MS_LOG(INFO) << "Lens=" << lens << "; use thread_num=" << thread_num << "; max_thread_num: " << max_thread_num;
-  std::vector<std::thread> threads;
-  if (thread_num < 1) {
-    MS_LOG(ERROR) << "Invalid value: thread_num " << thread_num;
-    return false;
-  }
-  threads.reserve(thread_num);
-  size_t start = 0;
-  size_t once_compute_size = (lens + thread_num - 1) / thread_num;
-  if (once_compute_size < 1) {
-    MS_LOG(ERROR) << "Invalid value: once_compute_size " << once_compute_size;
-    return false;
-  }
-  while (start < lens) {
-    size_t end = (start + once_compute_size) > lens ? lens : (start + once_compute_size);
-    threads.emplace_back(std::thread(&AdamCPUKernel::LaunchAdam<float>, this, var, m, v, new_lr, beta1, beta2, epsilon,
-                                     gradient, start, end));
-    start += once_compute_size;
-  }
-  for (size_t i = 0; i < threads.size(); ++i) {
-    threads[i].join();
-  }
+  LaunchAdam<float>(var, m, v, new_lr, beta1, beta2, epsilon, gradient, lens);
   return true;
 }
 } // namespace kernel
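Note: the chunk-and-join boilerplate deleted above (and again in cast_cpu_kernel.cc below) now lives behind CPUKernelUtils::ParallelFor. Its definition is not part of this diff; a minimal sketch of what such a helper does, assuming the signature used here (a range task plus an element count) and the removed code's roughly-128-elements-per-thread heuristic:

#include <algorithm>
#include <cstddef>
#include <functional>
#include <thread>
#include <vector>

// Hypothetical stand-in for CPUKernelUtils::ParallelFor (the real one lives in
// MindSpore's CPU kernel utilities): split [0, count) into contiguous chunks
// and run task(start, end) on each chunk in its own thread.
void ParallelFor(const std::function<void(size_t, size_t)> &task, size_t count) {
  size_t max_threads = std::thread::hardware_concurrency();
  size_t thread_num = count < 128 * max_threads ? (count + 127) / 128 : max_threads;
  thread_num = std::max<size_t>(thread_num, 1);
  size_t chunk = (count + thread_num - 1) / thread_num;  // elements per thread
  std::vector<std::thread> threads;
  threads.reserve(thread_num);
  for (size_t start = 0; start < count; start += chunk) {
    size_t end = std::min(start + chunk, count);
    threads.emplace_back([&task, start, end]() { task(start, end); });
  }
  for (auto &t : threads) {
    t.join();  // synchronous: all chunks finish before returning
  }
}

The behavior is unchanged; the gain is deduplication. Each kernel now supplies only the loop body, and the chunking and joining logic is written once.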

@@ -29,7 +29,7 @@ class AdamCPUKernel : public CPUKernel {
   ~AdamCPUKernel() override = default;
   template <typename T>
   void LaunchAdam(T *var, T *m, T *v, float lr, float beta1, float beta2, float epsilon, const T *gradient,
-                  size_t start, size_t end);
+                  size_t size);
   void InitKernel(const CNodePtr &kernel_node) override;
   bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,

@@ -40,43 +40,43 @@ class ArithmeticCPUKernel : public CPUKernel {
  private:
   void GenIndex(size_t num, std::vector<size_t> *tmp);
   template <typename T>
-  void Sub(const T *input1, const T *input2, T *out, size_t start, size_t end);
+  void Sub(const T *input1, const T *input2, T *out, size_t size);
   template <typename T>
-  void Add(const T *input1, const T *input2, T *out, size_t start, size_t end);
+  void Add(const T *input1, const T *input2, T *out, size_t size);
   template <typename T>
-  void Mul(const T *input1, const T *input2, T *out, size_t start, size_t end);
+  void Mul(const T *input1, const T *input2, T *out, size_t size);
   template <typename T>
-  void RealDiv(const T *input1, const T *input2, T *out, size_t start, size_t end);
+  void RealDiv(const T *input1, const T *input2, T *out, size_t size);
   template <typename T>
-  void Div(const T *input1, const T *input2, T *out, size_t start, size_t end);
+  void Div(const T *input1, const T *input2, T *out, size_t size);
   template <typename T>
-  void FloorDiv(const T *input1, const T *input2, T *out, size_t start, size_t end);
+  void FloorDiv(const T *input1, const T *input2, T *out, size_t size);
   template <typename T>
-  void Mod(const T *input1, const T *input2, T *out, size_t start, size_t end);
+  void Mod(const T *input1, const T *input2, T *out, size_t size);
   template <typename T>
-  void Pow(const T *input1, const T *input2, T *out, size_t start, size_t end);
+  void Pow(const T *input1, const T *input2, T *out, size_t size);
   template <typename T>
-  void AssignAdd(T *input1, const T *input2, T *out, size_t start, size_t end);
+  void AssignAdd(T *input1, const T *input2, T *out, size_t size);
   template <typename T>
-  void Atan2(const T *input1, const T *input2, T *out, size_t start, size_t end);
+  void Atan2(const T *input1, const T *input2, T *out, size_t size);
   template <typename T>
-  void Less(const T *input1, const T *input2, bool *out, size_t start, size_t end);
+  void Less(const T *input1, const T *input2, bool *out, size_t size);
   template <typename T>
-  void Equal(const T *input1, const T *input2, bool *out, size_t start, size_t end);
+  void Equal(const T *input1, const T *input2, bool *out, size_t size);
   template <typename T>
-  void NotEqual(const T *input1, const T *input2, bool *out, size_t start, size_t end);
+  void NotEqual(const T *input1, const T *input2, bool *out, size_t size);
   template <typename T>
-  void SquaredDifference(const T *input1, const T *input2, T *out, size_t start, size_t end);
+  void SquaredDifference(const T *input1, const T *input2, T *out, size_t size);
   template <typename T>
-  void Greater(const T *input1, const T *input2, bool *out, size_t start, size_t end);
+  void Greater(const T *input1, const T *input2, bool *out, size_t size);
   template <typename T>
-  void GreaterEqual(const T *input1, const T *input2, bool *out, size_t start, size_t end);
+  void GreaterEqual(const T *input1, const T *input2, bool *out, size_t size);
   template <typename T>
-  void LessEqual(const T *input1, const T *input2, bool *out, size_t start, size_t end);
+  void LessEqual(const T *input1, const T *input2, bool *out, size_t size);
   template <typename T>
-  void LogicalAnd(const T *input1, const T *input2, bool *out, size_t start, size_t end);
+  void LogicalAnd(const T *input1, const T *input2, bool *out, size_t size);
   template <typename T>
-  void LogicalOr(const T *input1, const T *input2, bool *out, size_t start, size_t end);
+  void LogicalOr(const T *input1, const T *input2, bool *out, size_t size);
   std::vector<size_t> input_shape0_;
   std::vector<size_t> input_shape1_;
   std::vector<size_t> input_element_num0_;
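The .cc side of ArithmeticCPUKernel is not shown in this diff, but each of these ops presumably follows the same pattern as LaunchAdam above: wrap the old (start, end) loop in a lambda and hand it to CPUKernelUtils::ParallelFor. A hedged sketch for Sub, ignoring the broadcast indexing that GenIndex supports:

// Sketch only; the real ArithmeticCPUKernel::Sub body is not in this diff.
template <typename T>
void ArithmeticCPUKernel::Sub(const T *input1, const T *input2, T *out, size_t size) {
  auto task = [&](size_t start, size_t end) {
    for (size_t i = start; i < end; i++) {
      out[i] = input1[i] - input2[i];  // elementwise difference over [start, end)
    }
  };
  CPUKernelUtils::ParallelFor(task, size);
}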

@@ -34,7 +34,6 @@ class ArithmeticSelfCPUKernel : public CPUKernel {
   template <typename T>
   void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
   template <typename T>
   void LaunchKernelLogic(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);

@@ -16,220 +16,38 @@
 #include <cmath>
 #include <map>
 #include <string>
-#include <thread>
 #include "backend/kernel_compiler/cpu/cast_cpu_kernel.h"
 #include "runtime/device/cpu/cpu_device_address.h"

 namespace mindspore {
 namespace kernel {
 template <typename S, typename T>
-void Cast(const S *in, T *out, size_t start, size_t end) {
+void Cast(const S *in, T *out, size_t size) {
+  auto task = [&](size_t start, size_t end) {
   for (size_t i = start; i < end; i++) {
     out[i] = static_cast<T>(in[i]);
   }
+  };
+  CPUKernelUtils::ParallelFor(task, size);
 }
 template <typename S, typename T>
-void LaunchCast(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs) {
-  S *input = reinterpret_cast<S *>(inputs[0]->addr);
-  T *output = reinterpret_cast<T *>(outputs[0]->addr);
-  MS_LOG(DEBUG) << "Type source: " << typeid(S).name() << "; target: " << typeid(T).name();
-  size_t lens = outputs[0]->size > 0 ? static_cast<size_t>(outputs[0]->size / sizeof(T)) : 1;
-  auto max_thread_num = std::thread::hardware_concurrency();
-  size_t thread_num = lens < 128 * max_thread_num ? std::ceil(lens / 128.0) : max_thread_num;
-  MS_LOG(INFO) << "Lens=" << lens << "; use thread_num=" << thread_num << "; max_thread_num: " << max_thread_num;
-  std::vector<std::thread> threads;
-  if (thread_num < 1) {
-    MS_LOG(ERROR) << "Invalid value: thread_num " << thread_num;
-    return;
-  }
-  threads.reserve(thread_num);
-  size_t start = 0;
-  size_t once_compute_size = (lens + thread_num - 1) / thread_num;
-  if (once_compute_size < 1) {
-    MS_LOG(ERROR) << "Invalid value: once_compute_size " << once_compute_size;
-    return;
-  }
-  while (start < lens) {
-    size_t end = (start + once_compute_size) > lens ? lens : (start + once_compute_size);
-    threads.emplace_back(std::thread(Cast<S, T>, input, output, start, end));
-    start += once_compute_size;
-  }
-  for (size_t i = 0; i < threads.size(); ++i) {
-    threads[i].join();
-  }
-}
-void CastCPUKernel::InitKernel(const CNodePtr &kernel_node) {
+void CastCPUKernel<S, T>::InitKernel(const CNodePtr &kernel_node) {
   MS_EXCEPTION_IF_NULL(kernel_node);
   source_dtype = AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, 0);
   target_dtype = AnfAlgo::GetOutputInferDataType(kernel_node, 0);
 }

-bool CastCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
+template <typename S, typename T>
+bool CastCPUKernel<S, T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
                            const std::vector<kernel::AddressPtr> & /*workspace*/,
                            const std::vector<kernel::AddressPtr> &outputs) {
-  using TypePair =
-    std::function<void(const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;
-  std::map<TypeId, std::map<TypeId, TypePair>> mode_map;
-  mode_map[kNumberTypeBool][kNumberTypeFloat16] = LaunchCast<bool, float16>;
-  mode_map[kNumberTypeBool][kNumberTypeFloat32] = LaunchCast<bool, float>;
-  mode_map[kNumberTypeBool][kNumberTypeFloat64] = LaunchCast<bool, double>;
-  mode_map[kNumberTypeBool][kNumberTypeInt8] = LaunchCast<bool, int8_t>;
-  mode_map[kNumberTypeBool][kNumberTypeInt16] = LaunchCast<bool, int16_t>;
-  mode_map[kNumberTypeBool][kNumberTypeInt32] = LaunchCast<bool, int32_t>;
-  mode_map[kNumberTypeBool][kNumberTypeInt64] = LaunchCast<bool, int64_t>;
-  mode_map[kNumberTypeBool][kNumberTypeUInt8] = LaunchCast<bool, uint8_t>;
-  mode_map[kNumberTypeBool][kNumberTypeUInt16] = LaunchCast<bool, uint16_t>;
-  mode_map[kNumberTypeBool][kNumberTypeUInt32] = LaunchCast<bool, uint32_t>;
-  mode_map[kNumberTypeBool][kNumberTypeUInt64] = LaunchCast<bool, uint64_t>;
-  mode_map[kNumberTypeBool][kNumberTypeBool] = LaunchCast<bool, bool>;
-  mode_map[kNumberTypeFloat16][kNumberTypeFloat16] = LaunchCast<float16, float16>;
-  mode_map[kNumberTypeFloat16][kNumberTypeFloat32] = LaunchCast<float16, float>;
-  mode_map[kNumberTypeFloat16][kNumberTypeFloat64] = LaunchCast<float16, double>;
-  mode_map[kNumberTypeFloat16][kNumberTypeInt8] = LaunchCast<float16, int8_t>;
-  mode_map[kNumberTypeFloat16][kNumberTypeInt16] = LaunchCast<float16, int16_t>;
-  mode_map[kNumberTypeFloat16][kNumberTypeInt32] = LaunchCast<float16, int32_t>;
-  mode_map[kNumberTypeFloat16][kNumberTypeInt64] = LaunchCast<float16, int64_t>;
-  mode_map[kNumberTypeFloat16][kNumberTypeUInt8] = LaunchCast<float16, uint8_t>;
-  mode_map[kNumberTypeFloat16][kNumberTypeUInt16] = LaunchCast<float16, uint16_t>;
-  mode_map[kNumberTypeFloat16][kNumberTypeUInt32] = LaunchCast<float16, uint32_t>;
-  mode_map[kNumberTypeFloat16][kNumberTypeUInt64] = LaunchCast<float16, uint64_t>;
-  mode_map[kNumberTypeFloat16][kNumberTypeBool] = LaunchCast<float16, bool>;
-  mode_map[kNumberTypeFloat32][kNumberTypeFloat16] = LaunchCast<float, float16>;
-  mode_map[kNumberTypeFloat32][kNumberTypeFloat32] = LaunchCast<float, float>;
-  mode_map[kNumberTypeFloat32][kNumberTypeFloat64] = LaunchCast<float, double>;
-  mode_map[kNumberTypeFloat32][kNumberTypeInt8] = LaunchCast<float, int8_t>;
-  mode_map[kNumberTypeFloat32][kNumberTypeInt16] = LaunchCast<float, int16_t>;
-  mode_map[kNumberTypeFloat32][kNumberTypeInt32] = LaunchCast<float, int32_t>;
-  mode_map[kNumberTypeFloat32][kNumberTypeInt64] = LaunchCast<float, int64_t>;
-  mode_map[kNumberTypeFloat32][kNumberTypeUInt8] = LaunchCast<float, uint8_t>;
-  mode_map[kNumberTypeFloat32][kNumberTypeUInt16] = LaunchCast<float, uint16_t>;
-  mode_map[kNumberTypeFloat32][kNumberTypeUInt32] = LaunchCast<float, uint32_t>;
-  mode_map[kNumberTypeFloat32][kNumberTypeUInt64] = LaunchCast<float, uint64_t>;
-  mode_map[kNumberTypeFloat32][kNumberTypeBool] = LaunchCast<float, bool>;
-  mode_map[kNumberTypeFloat64][kNumberTypeFloat16] = LaunchCast<double, float16>;
-  mode_map[kNumberTypeFloat64][kNumberTypeFloat32] = LaunchCast<double, float>;
-  mode_map[kNumberTypeFloat64][kNumberTypeFloat64] = LaunchCast<double, double>;
-  mode_map[kNumberTypeFloat64][kNumberTypeInt8] = LaunchCast<double, int8_t>;
-  mode_map[kNumberTypeFloat64][kNumberTypeInt16] = LaunchCast<double, int16_t>;
-  mode_map[kNumberTypeFloat64][kNumberTypeInt32] = LaunchCast<double, int32_t>;
-  mode_map[kNumberTypeFloat64][kNumberTypeInt64] = LaunchCast<double, int64_t>;
-  mode_map[kNumberTypeFloat64][kNumberTypeUInt8] = LaunchCast<double, uint8_t>;
-  mode_map[kNumberTypeFloat64][kNumberTypeUInt16] = LaunchCast<double, uint16_t>;
-  mode_map[kNumberTypeFloat64][kNumberTypeUInt32] = LaunchCast<double, uint32_t>;
-  mode_map[kNumberTypeFloat64][kNumberTypeUInt64] = LaunchCast<double, uint64_t>;
-  mode_map[kNumberTypeFloat64][kNumberTypeBool] = LaunchCast<double, bool>;
-  mode_map[kNumberTypeInt8][kNumberTypeFloat16] = LaunchCast<int8_t, float16>;
-  mode_map[kNumberTypeInt8][kNumberTypeFloat32] = LaunchCast<int8_t, float>;
-  mode_map[kNumberTypeInt8][kNumberTypeFloat64] = LaunchCast<int8_t, double>;
-  mode_map[kNumberTypeInt8][kNumberTypeInt8] = LaunchCast<int8_t, int8_t>;
-  mode_map[kNumberTypeInt8][kNumberTypeInt16] = LaunchCast<int8_t, int16_t>;
-  mode_map[kNumberTypeInt8][kNumberTypeInt32] = LaunchCast<int8_t, int32_t>;
-  mode_map[kNumberTypeInt8][kNumberTypeInt64] = LaunchCast<int8_t, int64_t>;
-  mode_map[kNumberTypeInt8][kNumberTypeUInt8] = LaunchCast<int8_t, uint8_t>;
-  mode_map[kNumberTypeInt8][kNumberTypeUInt16] = LaunchCast<int8_t, uint16_t>;
-  mode_map[kNumberTypeInt8][kNumberTypeUInt32] = LaunchCast<int8_t, uint32_t>;
-  mode_map[kNumberTypeInt8][kNumberTypeUInt64] = LaunchCast<int8_t, uint64_t>;
-  mode_map[kNumberTypeInt8][kNumberTypeBool] = LaunchCast<int8_t, bool>;
-  mode_map[kNumberTypeInt16][kNumberTypeFloat16] = LaunchCast<int16_t, float16>;
-  mode_map[kNumberTypeInt16][kNumberTypeFloat32] = LaunchCast<int16_t, float>;
-  mode_map[kNumberTypeInt16][kNumberTypeFloat64] = LaunchCast<int16_t, double>;
-  mode_map[kNumberTypeInt16][kNumberTypeInt8] = LaunchCast<int16_t, int8_t>;
-  mode_map[kNumberTypeInt16][kNumberTypeInt16] = LaunchCast<int16_t, int16_t>;
-  mode_map[kNumberTypeInt16][kNumberTypeInt32] = LaunchCast<int16_t, int32_t>;
-  mode_map[kNumberTypeInt16][kNumberTypeInt64] = LaunchCast<int16_t, int64_t>;
-  mode_map[kNumberTypeInt16][kNumberTypeUInt8] = LaunchCast<int16_t, uint8_t>;
-  mode_map[kNumberTypeInt16][kNumberTypeUInt16] = LaunchCast<int16_t, uint16_t>;
-  mode_map[kNumberTypeInt16][kNumberTypeUInt32] = LaunchCast<int16_t, uint32_t>;
-  mode_map[kNumberTypeInt16][kNumberTypeUInt64] = LaunchCast<int16_t, uint64_t>;
-  mode_map[kNumberTypeInt16][kNumberTypeBool] = LaunchCast<int16_t, bool>;
-  mode_map[kNumberTypeInt32][kNumberTypeFloat16] = LaunchCast<int32_t, float16>;
-  mode_map[kNumberTypeInt32][kNumberTypeFloat32] = LaunchCast<int32_t, float>;
-  mode_map[kNumberTypeInt32][kNumberTypeFloat64] = LaunchCast<int32_t, double>;
-  mode_map[kNumberTypeInt32][kNumberTypeInt8] = LaunchCast<int32_t, int8_t>;
-  mode_map[kNumberTypeInt32][kNumberTypeInt16] = LaunchCast<int32_t, int16_t>;
-  mode_map[kNumberTypeInt32][kNumberTypeInt32] = LaunchCast<int32_t, int32_t>;
-  mode_map[kNumberTypeInt32][kNumberTypeInt64] = LaunchCast<int32_t, int64_t>;
-  mode_map[kNumberTypeInt32][kNumberTypeUInt8] = LaunchCast<int32_t, uint8_t>;
-  mode_map[kNumberTypeInt32][kNumberTypeUInt16] = LaunchCast<int32_t, uint16_t>;
-  mode_map[kNumberTypeInt32][kNumberTypeUInt32] = LaunchCast<int32_t, uint32_t>;
-  mode_map[kNumberTypeInt32][kNumberTypeUInt64] = LaunchCast<int32_t, uint64_t>;
-  mode_map[kNumberTypeInt32][kNumberTypeBool] = LaunchCast<int32_t, bool>;
-  mode_map[kNumberTypeInt64][kNumberTypeFloat16] = LaunchCast<int64_t, float16>;
-  mode_map[kNumberTypeInt64][kNumberTypeFloat32] = LaunchCast<int64_t, float>;
-  mode_map[kNumberTypeInt64][kNumberTypeFloat64] = LaunchCast<int64_t, double>;
-  mode_map[kNumberTypeInt64][kNumberTypeInt8] = LaunchCast<int64_t, int8_t>;
-  mode_map[kNumberTypeInt64][kNumberTypeInt16] = LaunchCast<int64_t, int16_t>;
-  mode_map[kNumberTypeInt64][kNumberTypeInt32] = LaunchCast<int64_t, int32_t>;
-  mode_map[kNumberTypeInt64][kNumberTypeInt64] = LaunchCast<int64_t, int64_t>;
-  mode_map[kNumberTypeInt64][kNumberTypeUInt8] = LaunchCast<int64_t, uint8_t>;
-  mode_map[kNumberTypeInt64][kNumberTypeUInt16] = LaunchCast<int64_t, uint16_t>;
-  mode_map[kNumberTypeInt64][kNumberTypeUInt32] = LaunchCast<int64_t, uint32_t>;
-  mode_map[kNumberTypeInt64][kNumberTypeUInt64] = LaunchCast<int64_t, uint64_t>;
-  mode_map[kNumberTypeInt64][kNumberTypeBool] = LaunchCast<int64_t, bool>;
-  mode_map[kNumberTypeUInt8][kNumberTypeFloat16] = LaunchCast<uint8_t, float16>;
-  mode_map[kNumberTypeUInt8][kNumberTypeFloat32] = LaunchCast<uint8_t, float>;
-  mode_map[kNumberTypeUInt8][kNumberTypeFloat64] = LaunchCast<uint8_t, double>;
-  mode_map[kNumberTypeUInt8][kNumberTypeInt8] = LaunchCast<uint8_t, int8_t>;
-  mode_map[kNumberTypeUInt8][kNumberTypeInt16] = LaunchCast<uint8_t, int16_t>;
-  mode_map[kNumberTypeUInt8][kNumberTypeInt32] = LaunchCast<uint8_t, int32_t>;
-  mode_map[kNumberTypeUInt8][kNumberTypeInt64] = LaunchCast<uint8_t, int64_t>;
-  mode_map[kNumberTypeUInt8][kNumberTypeUInt8] = LaunchCast<uint8_t, uint8_t>;
-  mode_map[kNumberTypeUInt8][kNumberTypeUInt16] = LaunchCast<uint8_t, uint16_t>;
-  mode_map[kNumberTypeUInt8][kNumberTypeUInt32] = LaunchCast<uint8_t, uint32_t>;
-  mode_map[kNumberTypeUInt8][kNumberTypeUInt64] = LaunchCast<uint8_t, uint64_t>;
-  mode_map[kNumberTypeUInt8][kNumberTypeBool] = LaunchCast<uint8_t, bool>;
-  mode_map[kNumberTypeUInt16][kNumberTypeFloat16] = LaunchCast<uint16_t, float16>;
-  mode_map[kNumberTypeUInt16][kNumberTypeFloat32] = LaunchCast<uint16_t, float>;
-  mode_map[kNumberTypeUInt16][kNumberTypeFloat64] = LaunchCast<uint16_t, double>;
-  mode_map[kNumberTypeUInt16][kNumberTypeInt8] = LaunchCast<uint16_t, int8_t>;
-  mode_map[kNumberTypeUInt16][kNumberTypeInt16] = LaunchCast<uint16_t, int16_t>;
-  mode_map[kNumberTypeUInt16][kNumberTypeInt32] = LaunchCast<uint16_t, int32_t>;
-  mode_map[kNumberTypeUInt16][kNumberTypeInt64] = LaunchCast<uint16_t, int64_t>;
-  mode_map[kNumberTypeUInt16][kNumberTypeUInt8] = LaunchCast<uint16_t, uint8_t>;
-  mode_map[kNumberTypeUInt16][kNumberTypeUInt16] = LaunchCast<uint16_t, uint16_t>;
-  mode_map[kNumberTypeUInt16][kNumberTypeUInt32] = LaunchCast<uint16_t, uint32_t>;
-  mode_map[kNumberTypeUInt16][kNumberTypeUInt64] = LaunchCast<uint16_t, uint64_t>;
-  mode_map[kNumberTypeUInt16][kNumberTypeBool] = LaunchCast<uint16_t, bool>;
-  mode_map[kNumberTypeUInt32][kNumberTypeFloat16] = LaunchCast<uint32_t, float16>;
-  mode_map[kNumberTypeUInt32][kNumberTypeFloat32] = LaunchCast<uint32_t, float>;
-  mode_map[kNumberTypeUInt32][kNumberTypeFloat64] = LaunchCast<uint32_t, double>;
-  mode_map[kNumberTypeUInt32][kNumberTypeInt8] = LaunchCast<uint32_t, int8_t>;
-  mode_map[kNumberTypeUInt32][kNumberTypeInt16] = LaunchCast<uint32_t, int16_t>;
-  mode_map[kNumberTypeUInt32][kNumberTypeInt32] = LaunchCast<uint32_t, int32_t>;
-  mode_map[kNumberTypeUInt32][kNumberTypeInt64] = LaunchCast<uint32_t, int64_t>;
-  mode_map[kNumberTypeUInt32][kNumberTypeUInt8] = LaunchCast<uint32_t, uint8_t>;
-  mode_map[kNumberTypeUInt32][kNumberTypeUInt16] = LaunchCast<uint32_t, uint16_t>;
-  mode_map[kNumberTypeUInt32][kNumberTypeUInt32] = LaunchCast<uint32_t, uint32_t>;
-  mode_map[kNumberTypeUInt32][kNumberTypeUInt64] = LaunchCast<uint32_t, uint64_t>;
-  mode_map[kNumberTypeUInt32][kNumberTypeBool] = LaunchCast<uint32_t, bool>;
-  mode_map[kNumberTypeUInt64][kNumberTypeFloat16] = LaunchCast<uint64_t, float16>;
-  mode_map[kNumberTypeUInt64][kNumberTypeFloat32] = LaunchCast<uint64_t, float>;
-  mode_map[kNumberTypeUInt64][kNumberTypeFloat64] = LaunchCast<uint64_t, double>;
-  mode_map[kNumberTypeUInt64][kNumberTypeInt8] = LaunchCast<uint64_t, int8_t>;
-  mode_map[kNumberTypeUInt64][kNumberTypeInt16] = LaunchCast<uint64_t, int16_t>;
-  mode_map[kNumberTypeUInt64][kNumberTypeInt32] = LaunchCast<uint64_t, int32_t>;
-  mode_map[kNumberTypeUInt64][kNumberTypeInt64] = LaunchCast<uint64_t, int64_t>;
-  mode_map[kNumberTypeUInt64][kNumberTypeUInt8] = LaunchCast<uint64_t, uint8_t>;
-  mode_map[kNumberTypeUInt64][kNumberTypeUInt16] = LaunchCast<uint64_t, uint16_t>;
-  mode_map[kNumberTypeUInt64][kNumberTypeUInt32] = LaunchCast<uint64_t, uint32_t>;
-  mode_map[kNumberTypeUInt64][kNumberTypeUInt64] = LaunchCast<uint64_t, uint64_t>;
-  mode_map[kNumberTypeUInt64][kNumberTypeBool] = LaunchCast<uint64_t, bool>;
+  S *input = reinterpret_cast<S *>(inputs[0]->addr);
+  T *output = reinterpret_cast<T *>(outputs[0]->addr);
+  MS_LOG(DEBUG) << "Type source: " << typeid(S).name() << "; target: " << typeid(T).name();
-  mode_map[source_dtype][target_dtype](inputs, outputs);
+  size_t lens = outputs[0]->size > 0 ? static_cast<size_t>(outputs[0]->size / sizeof(T)) : 1;
+  Cast<S, T>(input, output, lens);
   return true;
 }
 } // namespace kernel
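Besides enabling ParallelFor, templating CastCPUKernel removes a real per-launch cost: the old Launch rebuilt the 144-entry mode_map (12 source types x 12 target types) on every invocation and then dispatched through std::function. Type selection now happens once, when the kernel is created. Since the member definitions moved into the .cc, each (S, T) pair has to be instantiated explicitly somewhere; a hedged sketch of what that could look like, with the registration-to-type mapping assumed to live elsewhere:

// Sketch: explicit instantiations for a few (source, target) pairs. The real
// file presumably covers the full 12 x 12 grid, matching the old mode_map.
template class CastCPUKernel<float, float16>;
template class CastCPUKernel<float, int32_t>;
template class CastCPUKernel<int32_t, float>;
template class CastCPUKernel<bool, float>;

The trade-off is the usual one for compile-time dispatch: somewhat more generated code in exchange for no map construction or std::function indirection on the hot path.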

@@ -14,8 +14,11 @@
  * limitations under the License.
  */
+#include <cmath>
 #include "backend/kernel_compiler/cpu/layer_norm_cpu_kernel.h"
+#include "backend/kernel_compiler/common_utils.h"
 #include "runtime/device/cpu/cpu_device_address.h"
+#include "common/thread_pool.h"

 namespace mindspore {
 namespace kernel {
@@ -72,7 +75,18 @@ void LayerNormCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, con
   auto y = reinterpret_cast<T *>(outputs[0]->addr);
   auto mean = reinterpret_cast<T *>(outputs[1]->addr);
   auto var = reinterpret_cast<T *>(outputs[2]->addr);
-  for (size_t i = 0; i < block_num_; ++i) {
+  size_t thread_num = common::ThreadPool::GetInstance().GetSyncRunThreadNum();
+  if (block_num_ < thread_num) {
+    thread_num = block_num_;
+  }
+  std::vector<common::Task> tasks;
+  tasks.reserve(thread_num);
+  auto task = [&](size_t start, size_t end) {
+    for (size_t c = 0; c < ceil(static_cast<double>(block_num_) / thread_num); ++c) {
+      if (c * thread_num + start >= block_num_) {
+        continue;
+      }
+      size_t i = c * thread_num + start;
     T sum = (T)0.0;
     T square_sum = (T)0.0;
     for (size_t j = i * block_size_; j < (i + 1) * block_size_; ++j) {
@@ -89,6 +103,15 @@ void LayerNormCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, con
     mean[i] = block_mean;
     var[i] = block_var;
   }
+  };
+  for (size_t i = 0; i < thread_num; ++i) {
+    auto block = [&, i]() {
+      task(i, i + 1);
+      return common::SUCCESS;
+    };
+    tasks.emplace_back(block);
+  }
+  common::ThreadPool::GetInstance().SyncRun(tasks);
 }

 void LayerNormCPUKernel::CheckParam(const CNodePtr &kernel_node) {
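The decomposition here is strided rather than chunked: the task invoked as task(start, start + 1) on thread start handles blocks start, start + thread_num, start + 2 * thread_num, and so on, until the bound check fires. A self-contained sketch of the same index arithmetic, using made-up values block_num_ = 5 and thread_num = 2:

#include <cmath>
#include <cstddef>
#include <cstdio>

int main() {
  const size_t block_num = 5, thread_num = 2;
  for (size_t start = 0; start < thread_num; ++start) {  // one task per thread
    // Same loop shape as the task lambda above.
    for (size_t c = 0; c < std::ceil(static_cast<double>(block_num) / thread_num); ++c) {
      if (c * thread_num + start >= block_num) {
        continue;
      }
      size_t i = c * thread_num + start;
      std::printf("thread %zu -> block %zu\n", start, i);
    }
  }
  return 0;
}

Thread 0 ends up with blocks 0, 2, 4 and thread 1 with blocks 1, 3, so the work stays balanced even when block_num_ is not a multiple of thread_num.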

@@ -15,7 +15,9 @@
  */
 #include "backend/kernel_compiler/cpu/layer_norm_grad_cpu_kernel.h"
+#include "backend/kernel_compiler/common_utils.h"
 #include "runtime/device/cpu/cpu_device_address.h"
+#include "common/thread_pool.h"

 namespace mindspore {
 namespace kernel {
@@ -73,23 +75,40 @@ void LayerNormGradCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
   auto dx = reinterpret_cast<T *>(outputs[0]->addr);
   auto dg = reinterpret_cast<T *>(outputs[1]->addr);
   auto db = reinterpret_cast<T *>(outputs[2]->addr);
-  for (size_t i = 0; i < param_num_; ++i) {
+  size_t thread_num = common::ThreadPool::GetInstance().GetSyncRunThreadNum();
+  auto thread_num1 = param_num_ < thread_num ? param_num_ : thread_num;
+  std::vector<common::Task> tasks1;
+  tasks1.reserve(thread_num1);
+  auto thread_num2 = block_num_ < thread_num ? block_num_ : thread_num;
+  std::vector<common::Task> tasks2;
+  tasks2.reserve(thread_num2);
+  auto task1 = [&](size_t start, size_t end) {
+    for (size_t c = 0; c < ceil(static_cast<double>(param_num_) / thread_num1); ++c) {
+      if (c * thread_num1 + start >= param_num_) {
+        continue;
+      }
+      size_t param_index = c * thread_num1 + start;
     T dgamma = (T)0.0;
     T dbeta = (T)0.0;
-    for (size_t j = i; j < param_size_ * param_num_; j += param_num_) {
+    for (size_t j = param_index; j < param_size_ * param_num_; j += param_num_) {
       auto norm_shift = static_cast<int>(j / block_size_);
       dgamma += dy[j] * (T)std::pow(static_cast<double>(var[norm_shift]) + eps_, -0.5) * (x[j] - mean[norm_shift]);
       dbeta += dy[j];
     }
-    dg[i] = dgamma;
-    db[i] = dbeta;
+    dg[param_index] = dgamma;
+    db[param_index] = dbeta;
   }
+  };
+  auto task2 = [&](size_t start, size_t end) {
+    for (size_t c = 0; c < ceil(static_cast<double>(block_num_) / thread_num2); ++c) {
+      if (c * thread_num2 + start >= block_num_) {
+        continue;
+      }
-  for (size_t i = 0; i < block_num_; ++i) {
+      size_t block_index = c * thread_num2 + start;
     T sum1 = (T)0.0;
     T sum2 = (T)0.0;
     T sum3 = (T)0.0;
-    for (size_t j = i * block_size_; j < (i + 1) * block_size_; ++j) {
+      for (size_t j = block_index * block_size_; j < (block_index + 1) * block_size_; ++j) {
       auto param_shift = j % param_num_;
       auto norm_shift = static_cast<int>(j / block_size_);
       auto dxm = x[j] - mean[norm_shift];
@@ -98,7 +117,7 @@ void LayerNormGradCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
       sum2 += dyg;
       sum3 += (T)(-2.0) * dxm;
     }
-    for (size_t j = i * block_size_; j < (i + 1) * block_size_; ++j) {
+      for (size_t j = block_index * block_size_; j < (block_index + 1) * block_size_; ++j) {
       auto param_shift = j % param_num_;
       auto norm_shift = static_cast<int>(j / block_size_);
       auto var_sqrt = (T)std::pow(static_cast<double>(var[norm_shift]) + eps_, -0.5);
@@ -108,6 +127,23 @@ void LayerNormGradCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
       dx[j] = dx1 + dx2 + dx3;
     }
   }
+  };
+  for (size_t i = 0; i < thread_num1; ++i) {
+    auto block = [&, i]() {
+      task1(i, i + 1);
+      return common::SUCCESS;
+    };
+    tasks1.emplace_back(block);
+  }
+  common::ThreadPool::GetInstance().SyncRun(tasks1);
+  for (size_t i = 0; i < thread_num2; ++i) {
+    auto block = [&, i]() {
+      task2(i, i + 1);
+      return common::SUCCESS;
+    };
+    tasks2.emplace_back(block);
+  }
+  common::ThreadPool::GetInstance().SyncRun(tasks2);
 }

 void LayerNormGradCPUKernel::CheckParam(const CNodePtr &kernel_node) {
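The wrap-each-thread-index-into-a-common::Task boilerplate now appears three times in this commit (twice here, once in UnsortedSegmentSum below). A small helper could fold it; a sketch, assuming common::Task accepts a callable that returns common::SUCCESS exactly as in the hunks above:

// Hypothetical helper, not part of this commit: schedule task(i, i + 1) for
// each i in [0, thread_num) on MindSpore's shared thread pool and wait.
template <typename Fn>
void SyncRunStrided(const Fn &task, size_t thread_num) {
  std::vector<common::Task> tasks;
  tasks.reserve(thread_num);
  for (size_t i = 0; i < thread_num; ++i) {
    tasks.emplace_back([&task, i]() {
      task(i, i + 1);
      return common::SUCCESS;
    });
  }
  common::ThreadPool::GetInstance().SyncRun(tasks);
}

With it, the body above reduces to SyncRunStrided(task1, thread_num1) followed by SyncRunStrided(task2, thread_num2). The two phases write disjoint outputs (task1 fills dg/db, task2 fills dx), so running them as two sequential synchronous batches is correct; the second SyncRun simply starts after the first returns.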

@@ -16,6 +16,7 @@
 #include "backend/kernel_compiler/cpu/unsorted_segment_sum_cpu_kernel.h"
 #include <string>
 #include <cmath>
 #include "runtime/device/cpu/cpu_device_address.h"
+#include "common/thread_pool.h"
@@ -78,7 +79,18 @@ bool UnsortedSegmentSumCPUKernel::LaunchKernel(const std::vector<AddressPtr> &in
     MS_LOG(ERROR) << "Output buff memset fail. ret:" << ret;
     return false;
   }
-  for (size_t i = 0; i < unit_num_; ++i) {
+  size_t thread_num = common::ThreadPool::GetInstance().GetSyncRunThreadNum();
+  if (unit_num_ < thread_num) {
+    thread_num = unit_num_;
+  }
+  std::vector<common::Task> tasks;
+  tasks.reserve(thread_num);
+  auto task = [&](size_t start, size_t end) {
+    for (size_t c = 0; c < ceil(static_cast<double>(unit_num_) / thread_num); ++c) {
+      if (c * thread_num + start >= unit_num_) {
+        continue;
+      }
+      size_t i = c * thread_num + start;
     size_t j = i / input_dim1_;
     size_t k = i % input_dim1_;
@@ -89,6 +101,15 @@ bool UnsortedSegmentSumCPUKernel::LaunchKernel(const std::vector<AddressPtr> &in
       size_t output_index = index * output_dim1_ + k;
       output_addr[output_index] += input_addr[i];
     }
+  };
+  for (size_t t = 0; t < thread_num; ++t) {
+    auto block = [&, t]() {
+      task(t, t + 1);
+      return common::SUCCESS;
+    };
+    tasks.emplace_back(block);
+  }
+  common::ThreadPool::GetInstance().SyncRun(tasks);
   return true;
 }
 } // namespace kernel
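For reference, the index arithmetic inside the task: the input is treated as a (unit_num_ / input_dim1_) x input_dim1_ matrix, unit i maps to row j = i / input_dim1_ and column k = i % input_dim1_, and the segment id of row j selects the output row. A self-contained sketch with made-up shapes (ids = {1, 0, 1}, input_dim1 = 2):

#include <cstddef>
#include <cstdio>

int main() {
  const size_t input_dim0 = 3, input_dim1 = 2, output_dim1 = 2;
  const float input[] = {1, 2, 3, 4, 5, 6};  // rows: (1,2) (3,4) (5,6)
  const int ids[] = {1, 0, 1};               // segment id per input row
  float output[4] = {0};                     // 2 segments x output_dim1, zeroed as in Launch
  for (size_t i = 0; i < input_dim0 * input_dim1; ++i) {
    size_t j = i / input_dim1;  // input row
    size_t k = i % input_dim1;  // column
    output[ids[j] * output_dim1 + k] += input[i];
  }
  // Prints 3 4 6 8: segment 0 = row 1; segment 1 = rows 0 + 2 summed.
  std::printf("%g %g %g %g\n", output[0], output[1], output[2], output[3]);
  return 0;
}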
