|
|
|
@ -20,39 +20,39 @@
|
|
|
|
|
|
|
|
|
|
namespace mindspore {
|
|
|
|
|
namespace kernel {
|
|
|
|
|
namespace {
|
|
|
|
|
template <typename T>
|
|
|
|
|
void Add(const T *input1, const T *input2, T *out, size_t start, size_t end, bool is_number) {
|
|
|
|
|
void ArithmeticCPUKernel::Add(const T *input1, const T *input2, T *out, size_t start, size_t end) {
|
|
|
|
|
for (size_t i = start; i < end; i++) {
|
|
|
|
|
out[i] = input1[i] + (is_number ? *input2 : input2[i]);
|
|
|
|
|
out[i] = input1[i] + input2[i];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
void Sub(const T *input1, const T *input2, T *out, size_t start, size_t end, bool is_number) {
|
|
|
|
|
void ArithmeticCPUKernel::Sub(const T *input1, const T *input2, T *out, size_t start, size_t end) {
|
|
|
|
|
for (size_t i = start; i < end; i++) {
|
|
|
|
|
out[i] = input1[i] - (is_number ? *input2 : input2[i]);
|
|
|
|
|
std::vector<size_t> idx;
|
|
|
|
|
GenIndex(i, &idx);
|
|
|
|
|
out[i] = input1[idx[0]] - input2[idx[1]];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
void Mul(const T *input1, const T *input2, T *out, size_t start, size_t end, bool is_number) {
|
|
|
|
|
void ArithmeticCPUKernel::Mul(const T *input1, const T *input2, T *out, size_t start, size_t end) {
|
|
|
|
|
for (size_t i = start; i < end; i++) {
|
|
|
|
|
out[i] = input1[i] * (is_number ? *input2 : input2[i]);
|
|
|
|
|
out[i] = input1[i] * input2[i];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
void Div(const T *input1, const T *input2, T *out, size_t start, size_t end, bool is_number) {
|
|
|
|
|
void ArithmeticCPUKernel::Div(const T *input1, const T *input2, T *out, size_t start, size_t end) {
|
|
|
|
|
for (size_t i = start; i < end; i++) {
|
|
|
|
|
auto div_number = is_number ? *input2 : input2[i];
|
|
|
|
|
auto div_number = input2[i];
|
|
|
|
|
if (div_number == 0) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Cannot divided by 0!";
|
|
|
|
|
}
|
|
|
|
|
out[i] = input1[i] / div_number;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
void ArithmeticCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_node);
|
|
|
|
@ -67,21 +67,20 @@ void ArithmeticCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
|
|
|
|
operate_type_ = DIV;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto shape0 = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
|
|
|
|
auto shape1 = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
|
|
|
|
|
if (shape1.size() == 0) {
|
|
|
|
|
is_number_ = true;
|
|
|
|
|
} else {
|
|
|
|
|
is_number_ = false;
|
|
|
|
|
if (shape0.size() != shape1.size()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Input0 and input1 must has the same shape";
|
|
|
|
|
}
|
|
|
|
|
for (size_t i = 0; i < shape0.size(); ++i) {
|
|
|
|
|
if (shape0[i] != shape1[i]) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Input0 and input1 must has the same shape";
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
input_shape0_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
|
|
|
|
input_shape1_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
|
|
|
|
|
output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0);
|
|
|
|
|
size_t l = input_shape0_.size();
|
|
|
|
|
for (size_t i = 0; i < output_shape_.size() - l; ++i) {
|
|
|
|
|
input_shape0_.insert(input_shape0_.begin(), 1);
|
|
|
|
|
}
|
|
|
|
|
l = input_shape1_.size();
|
|
|
|
|
for (size_t i = 0; i < output_shape_.size() - l; ++i) {
|
|
|
|
|
input_shape1_.insert(input_shape1_.begin(), 1);
|
|
|
|
|
}
|
|
|
|
|
CPUKernelUtils::GetElementNumEveryDim(input_shape0_, &input_element_num0_);
|
|
|
|
|
CPUKernelUtils::GetElementNumEveryDim(input_shape1_, &input_element_num1_);
|
|
|
|
|
CPUKernelUtils::GetElementNumEveryDim(output_shape_, &output_element_num_);
|
|
|
|
|
dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);
|
|
|
|
|
if (dtype_ != AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 1)) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Input0 and input1 must has the same data type";
|
|
|
|
@ -103,14 +102,43 @@ bool ArithmeticCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ArithmeticCPUKernel::GenIndex(size_t num, std::vector<size_t> *idx) {
|
|
|
|
|
std::vector<size_t> tmp;
|
|
|
|
|
for (size_t i = 0; i < output_shape_.size() - 1; ++i) {
|
|
|
|
|
if (output_element_num_[i] > num) {
|
|
|
|
|
tmp.push_back(0);
|
|
|
|
|
} else {
|
|
|
|
|
tmp.push_back(num / output_element_num_[i]);
|
|
|
|
|
num %= output_element_num_[i];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
tmp.push_back(num);
|
|
|
|
|
size_t idx0 = 0;
|
|
|
|
|
size_t idx1 = 0;
|
|
|
|
|
for (size_t k = 0; k < tmp.size() - 1; ++k) {
|
|
|
|
|
if (input_shape0_[k] > 1) {
|
|
|
|
|
idx0 += tmp[k] * input_element_num0_[k];
|
|
|
|
|
}
|
|
|
|
|
if (input_shape1_[k] > 1) {
|
|
|
|
|
idx1 += tmp[k] * input_element_num1_[k];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (input_shape0_[tmp.size() - 1] > 1) {
|
|
|
|
|
idx0 += tmp[tmp.size() - 1];
|
|
|
|
|
}
|
|
|
|
|
if (input_shape1_[tmp.size() - 1] > 1) {
|
|
|
|
|
idx1 += tmp[tmp.size() - 1];
|
|
|
|
|
}
|
|
|
|
|
idx->push_back(idx0);
|
|
|
|
|
idx->push_back(idx1);
|
|
|
|
|
}
|
|
|
|
|
template <typename T>
|
|
|
|
|
void ArithmeticCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) {
|
|
|
|
|
T *input1 = reinterpret_cast<T *>(inputs[0]->addr);
|
|
|
|
|
T *input2 = reinterpret_cast<T *>(inputs[1]->addr);
|
|
|
|
|
T *output = reinterpret_cast<T *>(outputs[0]->addr);
|
|
|
|
|
auto lens = inputs[0]->size / sizeof(T);
|
|
|
|
|
auto lens = outputs[0]->size / sizeof(T);
|
|
|
|
|
MS_LOG(INFO) << "lens=" << lens;
|
|
|
|
|
|
|
|
|
|
const size_t thread_num = 24;
|
|
|
|
|
std::vector<std::thread> threads;
|
|
|
|
|
threads.reserve(thread_num);
|
|
|
|
@ -119,13 +147,13 @@ void ArithmeticCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, co
|
|
|
|
|
while (start < lens) {
|
|
|
|
|
size_t end = (start + once_compute_size) > lens ? lens : (start + once_compute_size);
|
|
|
|
|
if (operate_type_ == ADD) {
|
|
|
|
|
threads.emplace_back(std::thread(Add<T>, input1, input2, output, start, end, is_number_));
|
|
|
|
|
threads.emplace_back(std::thread(&ArithmeticCPUKernel::Add<T>, this, input1, input2, output, start, end));
|
|
|
|
|
} else if (operate_type_ == SUB) {
|
|
|
|
|
threads.emplace_back(std::thread(Sub<T>, input1, input2, output, start, end, is_number_));
|
|
|
|
|
threads.emplace_back(std::thread(&ArithmeticCPUKernel::Sub<T>, this, input1, input2, output, start, end));
|
|
|
|
|
} else if (operate_type_ == MUL) {
|
|
|
|
|
threads.emplace_back(std::thread(Mul<T>, input1, input2, output, start, end, is_number_));
|
|
|
|
|
threads.emplace_back(std::thread(&ArithmeticCPUKernel::Mul<T>, this, input1, input2, output, start, end));
|
|
|
|
|
} else if (operate_type_ == DIV) {
|
|
|
|
|
threads.emplace_back(std::thread(Div<T>, input1, input2, output, start, end, is_number_));
|
|
|
|
|
threads.emplace_back(std::thread(&ArithmeticCPUKernel::Div<T>, this, input1, input2, output, start, end));
|
|
|
|
|
}
|
|
|
|
|
start += once_compute_size;
|
|
|
|
|
}
|
|
|
|
|