!13978 Reduce/Transpose/TensorAdd CPU kernel performance improve!

From: @yang_chun
Reviewed-by: @wuxuejian,@c_34
Signed-off-by: @wuxuejian,@c_34
pull/13978/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 1d5e903771

@ -1,65 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/cpu/mkldnn/tensoradd_cpu_kernel.h"
#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h"
#include "runtime/device/cpu/cpu_device_address.h"
#include "utils/ms_utils.h"
namespace mindspore {
namespace kernel {
void TensorAddCPUKernel::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
std::vector<size_t> src0_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
std::vector<size_t> src1_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
std::vector<size_t> dst_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
need_swap_ = BinaryBroadCast(&src0_shape, &src1_shape, &dst_shape);
dnnl::memory::desc src0_desc;
dnnl::memory::desc src1_desc;
if (need_swap_) {
src0_desc = GetDefaultMemDesc(src1_shape);
src1_desc = GetDefaultMemDesc(src0_shape);
} else {
src0_desc = GetDefaultMemDesc(src0_shape);
src1_desc = GetDefaultMemDesc(src1_shape);
}
dnnl::memory::desc dst_desc = GetDefaultMemDesc(dst_shape);
dnnl::binary::desc desc = dnnl::binary::desc(dnnl::algorithm::binary_add, src0_desc, src1_desc, dst_desc);
auto prim_desc = dnnl::binary::primitive_desc(desc, MKLKernelEngine::Get().engine());
primitive_ = std::make_shared<dnnl::binary>(prim_desc);
AddArgument(DNNL_ARG_SRC_0, src0_desc);
AddArgument(DNNL_ARG_SRC_1, src1_desc);
AddArgument(DNNL_ARG_DST, dst_desc);
}
// Binds the input/output buffers to the primitive arguments and executes it.
// Raises through MS_LOG(EXCEPTION) when fewer than two inputs or no output are
// supplied. Returns true once the primitive has been executed.
bool TensorAddCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
                                const std::vector<kernel::AddressPtr> & /*workspace*/,
                                const std::vector<kernel::AddressPtr> &outputs) {
  if (inputs.size() < 2 || outputs.empty()) {
    MS_LOG(EXCEPTION) << "TensorAdd error input output size!";
  }
  // need_swap_ (set in InitKernel) exchanges which input feeds SRC_0 and SRC_1.
  const size_t src0_index = need_swap_ ? 1 : 0;
  const size_t src1_index = need_swap_ ? 0 : 1;
  SetArgumentHandle(DNNL_ARG_SRC_0, inputs[src0_index]->addr);
  SetArgumentHandle(DNNL_ARG_SRC_1, inputs[src1_index]->addr);
  SetArgumentHandle(DNNL_ARG_DST, outputs[0]->addr);
  ExecutePrimitive();
  return true;
}
} // namespace kernel
} // namespace mindspore

@ -18,6 +18,7 @@
#include <vector>
#include <memory>
#include <string>
#include <functional>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
@ -33,15 +34,13 @@ class ReduceCPUKernel : public CPUKernel {
const std::vector<AddressPtr> &outputs) override;
private:
void Transpose(const int size, const T *input, const std::vector<size_t> &input_shape,
const std::vector<size_t> &input_axis, const int shape_size, T *output);
void ConvertDataToOutput(const T *input, T *output);
void CheckAxis(const CNodePtr &kernel_node);
size_t reduce_type_ = 0;
std::vector<size_t> axis_;
std::vector<size_t> shape_;
size_t left_dims_ = 1;
size_t stride_ = 1;
void CheckParameter() const;
void CalculateTransposeInfo(std::vector<size_t> *new_shape, std::vector<size_t> *strides,
std::vector<size_t> *back_strides, size_t *stride) const;
std::vector<size_t> input_shape_;
std::vector<int64_t> axis_;
int reduce_type_{0};
std::function<void(const T *, size_t, T *)> reduce_func_;
};
MS_REG_CPU_KERNEL_T(ReduceMean, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),

@ -0,0 +1,150 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/cpu/tensoradd_cpu_kernel.h"
#include <vector>
namespace mindspore {
namespace kernel {
namespace {
// Walks the elements of a broadcast output tensor in row-major order and maps
// the current output coordinate to the flat offsets of the two inputs.
//
// The input shapes are expected to have the same rank as the output shape
// (callers left-pad them with 1s); a dimension of extent 1 in an input is
// treated as broadcast and contributes nothing to that input's offset.
struct Iterator {
  std::vector<size_t> coordinates_;      // current coordinate in the output tensor
  std::vector<size_t> input_shape_a_;    // shape of the first input
  std::vector<size_t> input_shape_b_;    // shape of the second input
  std::vector<size_t> output_shape_;     // shape of the output
  std::vector<size_t> input_strides_a_;  // element strides of the first input
  std::vector<size_t> input_strides_b_;  // element strides of the second input
  int output_dimension_pos_{0};          // index of the innermost output dimension (-1 if rank 0)
  size_t pos_{0};                        // flat output offset the iterator was started at

  // Creates an iterator positioned at flat output offset `pos`.
  Iterator(const std::vector<size_t> &input_shape_a, const std::vector<size_t> &input_shape_b,
           const std::vector<size_t> &output_shape, const std::vector<size_t> &input_strides_a,
           const std::vector<size_t> &input_strides_b, size_t pos)
      : input_shape_a_(input_shape_a),
        input_shape_b_(input_shape_b),
        output_shape_(output_shape),
        input_strides_a_(input_strides_a),
        input_strides_b_(input_strides_b),
        pos_{pos} {
    output_dimension_pos_ = static_cast<int>(output_shape_.size()) - 1;
    // resize() value-initializes, so every coordinate starts at 0.
    coordinates_.resize(output_shape_.size());
    // Decompose the flat offset into per-dimension coordinates. The remainder is
    // kept as size_t: narrowing it to int would corrupt offsets above INT_MAX.
    size_t remaining = pos_;
    for (int i = output_dimension_pos_; i >= 0 && remaining != 0; --i) {
      coordinates_[i] = remaining % output_shape_[i];
      remaining /= output_shape_[i];
    }
  }

  // Advances to the next output element: increments the innermost coordinate and
  // carries into outer dimensions, wrapping to all zeros after the last element.
  void UpdateCoordinates() {
    for (int i = output_dimension_pos_; i >= 0; --i) {
      if (coordinates_[i] + 1 == output_shape_[i]) {
        coordinates_[i] = 0;
      } else {
        ++coordinates_[i];
        break;
      }
    }
  }

  // Writes the flat offsets into input a ((*position)[0]) and input b
  // ((*position)[1]) that correspond to the current output coordinate.
  void GenPoints(std::array<size_t, 2> *position) {
    auto &idx = *position;
    idx = {0, 0};
    if (output_dimension_pos_ < 0) {
      // Rank-0 output: there is no dimension to index, both offsets stay 0.
      return;
    }
    for (int k = 0; k < output_dimension_pos_; ++k) {
      if (input_shape_a_[k] > 1) {
        idx[0] += coordinates_[k] * input_strides_a_[k];
      }
      if (input_shape_b_[k] > 1) {
        idx[1] += coordinates_[k] * input_strides_b_[k];
      }
    }
    // The innermost dimension is added without a stride multiplication
    // (i.e. it is taken to have stride 1).
    if (input_shape_a_[output_dimension_pos_] > 1) {
      idx[0] += coordinates_[output_dimension_pos_];
    }
    if (input_shape_b_[output_dimension_pos_] > 1) {
      idx[1] += coordinates_[output_dimension_pos_];
    }
  }
};
} // namespace
// Caches the device shapes of both inputs and of the output for later launches.
// kernel_node: node this kernel is initialised from; checked with MS_EXCEPTION_IF_NULL.
void TensorAddCPUKernel::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
// Record the shapes only; strides are computed in BroadcastProcess when needed.
input_shape_a_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
input_shape_b_ = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
output_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
}
// Adds the two inputs element-wise into the output buffer.
// All three buffers are interpreted as arrays of float; the element count is
// derived from the byte size of the output buffer. Identical input shapes take
// the plain element-wise path, anything else goes through the broadcast path.
// Raises through MS_LOG(EXCEPTION) when fewer than two inputs or no output are
// supplied. Returns true on completion.
bool TensorAddCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
                                const std::vector<kernel::AddressPtr> & /*workspace*/,
                                const std::vector<kernel::AddressPtr> &outputs) {
  // Validate before indexing: inputs[0], inputs[1] and outputs[0] are read below.
  if (inputs.size() < 2 || outputs.empty()) {
    MS_LOG(EXCEPTION) << "TensorAdd error input output size!";
  }
  auto input_addr_a = reinterpret_cast<float *>(inputs[0]->addr);
  auto input_addr_b = reinterpret_cast<float *>(inputs[1]->addr);
  auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
  auto output_size = outputs[0]->size / sizeof(float);
  if (input_shape_a_ == input_shape_b_) {
    NormalProcess(input_addr_a, input_addr_b, output_addr, output_size);
  } else {  // Broadcast
    BroadcastProcess(input_addr_a, input_addr_b, output_addr, output_size);
  }
  return true;
}
// Element-wise sum for inputs that need no broadcasting:
// output[i] = input_a[i] + input_b[i] for every i in [0, size).
// The index range is handed to CPUKernelUtils::ParallelFor in [first, last) chunks.
void TensorAddCPUKernel::NormalProcess(const float *input_a, const float *input_b, float *output, size_t size) {
  auto add_range = [input_a, input_b, output](size_t first, size_t last) {
    size_t idx = first;
    while (idx < last) {
      output[idx] = input_a[idx] + input_b[idx];
      ++idx;
    }
  };
  CPUKernelUtils::ParallelFor(add_range, size);
}
void TensorAddCPUKernel::BroadcastProcess(const float *input_a, const float *input_b, float *output, size_t size) {
// Broadcast shape
int dimension = output_shape_.size();
int input_dimension_a = input_shape_a_.size();
if (input_dimension_a < dimension) {
input_shape_a_.insert(input_shape_a_.begin(), dimension - input_dimension_a, 1);
}
int input_dimension_b = input_shape_b_.size();
if (input_dimension_b < dimension) {
input_shape_b_.insert(input_shape_b_.begin(), dimension - input_dimension_b, 1);
}
// Calculate strides
CalculateStrides(input_shape_a_, &input_strides_a_);
CalculateStrides(input_shape_b_, &input_strides_b_);
auto task = [this, input_a, input_b, output](size_t start, size_t end) {
Iterator iter(input_shape_a_, input_shape_b_, output_shape_, input_strides_a_, input_strides_b_, start);
std::array<size_t, 2> position{0};
for (size_t i = start; i < end; ++i) {
iter.GenPoints(&position);
output[i] = input_a[position[0]] + input_b[position[1]];
iter.UpdateCoordinates();
}
};
CPUKernelUtils::ParallelFor(task, size);
}
// Computes row-major element strides for `shape` into `*strides`.
// The innermost stride is 1 and strides[i] = shape[i + 1] * strides[i + 1].
// `*strides` is fully overwritten: assign() is used instead of resize() so that
// values left over from an earlier call (possibly with a different rank) can
// never leak into the result.
void TensorAddCPUKernel::CalculateStrides(const std::vector<size_t> &shape, std::vector<size_t> *strides) {
  strides->assign(shape.size(), 1);
  // Signed index so the loop terminates cleanly for rank 0 and rank 1 shapes.
  for (int i = static_cast<int>(shape.size()) - 2; i >= 0; --i) {
    (*strides)[i] = shape[i + 1] * (*strides)[i + 1];
  }
}
} // namespace kernel
} // namespace mindspore

@ -18,11 +18,12 @@
#include <vector>
#include <memory>
#include "backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
namespace mindspore {
namespace kernel {
class TensorAddCPUKernel : public MKLCPUKernel {
class TensorAddCPUKernel : public CPUKernel {
public:
TensorAddCPUKernel() = default;
~TensorAddCPUKernel() override = default;
@ -33,7 +34,15 @@ class TensorAddCPUKernel : public MKLCPUKernel {
const std::vector<AddressPtr> &outputs) override;
private:
bool need_swap_{false};
static void NormalProcess(const float *input_a, const float *input_b, float *output, size_t size);
void BroadcastProcess(const float *input_a, const float *input_b, float *output, size_t size);
static void CalculateStrides(const std::vector<size_t> &, std::vector<size_t> *);
std::vector<size_t> input_shape_a_;
std::vector<size_t> input_shape_b_;
// Define follow var for Broadcast
std::vector<size_t> output_shape_;
std::vector<size_t> input_strides_a_;
std::vector<size_t> input_strides_b_;
};
MS_REG_CPU_KERNEL(

@ -16,19 +16,22 @@
#include "backend/kernel_compiler/cpu/transpose_cpu_kernel.h"
#include <algorithm>
#include <vector>
#include <unordered_set>
#include "runtime/device/cpu/cpu_device_address.h"
namespace mindspore {
namespace kernel {
const size_t kMaxDim = 100;
namespace {
const size_t kMaxDim = 10;
}
void TransposeCPUFwdKernel::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
std::vector<int64_t> axis_me = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, "perm");
(void)std::transform(axis_me.begin(), axis_me.end(), std::back_inserter(axis_),
[](const int64_t &value) { return static_cast<int>(value); });
if (shape_.size() != axis_.size()) {
MS_LOG(EXCEPTION) << "The size of input shape and transpose axis shape must be equal.";
}
input_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
output_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
axes_ = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, "perm");
CheckParameter();
dtype_ = AnfAlgo ::GetPrevNodeOutputDeviceDataType(kernel_node, 0);
if (dtype_ == kTypeUnknown) {
dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);
@ -53,45 +56,84 @@ void TransposeCPUFwdKernel::InitKernel(const CNodePtr &kernel_node) {
}
}
// Runs the transpose by forwarding to launch_func_ with this kernel, the inputs
// and the outputs; the workspace argument is unused. Returns true once the call
// returns.
// NOTE(review): launch_func_ is presumably bound in InitKernel to the
// LaunchKernel<T> instantiation matching dtype_ — confirm it can never be empty here.
bool TransposeCPUFwdKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> & /*workspace*/,
const std::vector<kernel::AddressPtr> &outputs) {
launch_func_(this, inputs, outputs);
return true;
}
// Validates input_shape_ and axes_ before the kernel is used.
// Raises through MS_LOG(EXCEPTION), in this order, when:
//   - the input rank exceeds kMaxDim,
//   - the input shape is empty,
//   - the number of perm axes differs from the input rank,
//   - perm contains the same axis more than once,
//   - any perm axis lies outside [0, input rank).
void TransposeCPUFwdKernel::CheckParameter() const {
  const size_t rank = input_shape_.size();
  if (rank > kMaxDim) {
    MS_LOG(EXCEPTION) << "Input tensor is " << rank << ", out of bound max dimension 10";
  }
  if (rank == 0) {
    MS_LOG(EXCEPTION) << "Input tensor is empty";
  }
  if (axes_.size() != rank) {
    MS_LOG(EXCEPTION) << "Input perm size is not equal with input shape";
  }

  // A set built from the axes is smaller than the list iff an axis repeats.
  const std::unordered_set<int64_t> distinct_axes(axes_.begin(), axes_.end());
  if (distinct_axes.size() < axes_.size()) {
    MS_LOG(EXCEPTION) << "Input perm is illegal, it has the same axis";
  }

  // Every axis must lie in the true range [0, rank).
  const int64_t upper_bound = rank;
  const bool has_out_of_range_axis =
    std::any_of(axes_.begin(), axes_.end(),
                [upper_bound](const int64_t &axis) { return axis < 0 || axis >= upper_bound; });
  if (has_out_of_range_axis) {
    MS_LOG(EXCEPTION) << "Input perm axis is out of bound input shape size";
  }
}
template <typename T>
void TransposeCPUFwdKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &outputs) {
auto input = reinterpret_cast<T *>(inputs[0]->addr);
auto output = reinterpret_cast<T *>(outputs[0]->addr);
size_t size = IntToSize(inputs[0]->size / sizeof(T));
size_t shape_size = IntToSize(shape_.size());
if (shape_size > kMaxDim) {
MS_LOG(EXCEPTION) << "Input is " << shape_size << "-D, but transpose supports max " << kMaxDim << "-D inputs.";
int dimension = input_shape_.size();
// Calculate input tensor strides
std::array<uint32_t, kMaxDim> input_strides{0};
input_strides[dimension - 1] = 1;
for (int i = dimension - 2; i >= 0; --i) {
input_strides[i] = input_shape_[i + 1] * input_strides[i + 1];
}
size_t pos_array[kMaxDim];
size_t size_offset[kMaxDim];
size_offset[0] = size / shape_[0];
for (size_t i = 1; i < shape_size; i++) {
size_offset[i] = size_offset[SizeToInt(i) - 1] / shape_[i];
// Calculate output strides and back strides
std::array<uint32_t, kMaxDim> strides{0};
std::array<uint32_t, kMaxDim> back_strides{0};
for (int i = dimension - 1; i >= 0; --i) {
strides[i] = input_strides[axes_[i]];
back_strides[i] = (output_shape_[i] - 1) * strides[i];
}
for (size_t position = 0; position < size; position += 1) {
size_t temp_position = position;
pos_array[0] = temp_position / size_offset[0];
for (size_t i = 1; i < shape_size; i++) {
temp_position -= pos_array[SizeToInt(i) - 1] * size_offset[i - 1];
pos_array[i] = temp_position / size_offset[i];
}
size_t new_position = pos_array[axis_[SizeToInt(shape_size) - 1]];
size_t new_position_size = 1;
for (int j = shape_size - 2; j >= 0; j--) {
new_position_size *= shape_[axis_[j + 1]];
new_position += pos_array[axis_[j]] * new_position_size;
std::array<uint32_t, kMaxDim> coordinates{0};
auto get_next_pos = [&coordinates, &strides, &back_strides, &dimension, this](int curr_pos) {
for (int i = dimension - 1; i >= 0; --i) {
if (coordinates[i] + 1 == output_shape_[i]) {
coordinates[i] = 0;
curr_pos -= back_strides[i];
} else {
coordinates[i]++;
curr_pos += strides[i];
break;
}
}
output[new_position] = input[position];
}
}
return curr_pos;
};
bool TransposeCPUFwdKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> & /*workspace*/,
const std::vector<kernel::AddressPtr> &outputs) {
launch_func_(this, inputs, outputs);
return true;
auto input = reinterpret_cast<T *>(inputs[0]->addr);
auto output = reinterpret_cast<T *>(outputs[0]->addr);
size_t size = IntToSize(inputs[0]->size / sizeof(T));
output[0] = input[0];
int pos = 0;
for (size_t i = 1; i < size; ++i) {
pos = get_next_pos(pos);
output[i] = input[pos];
}
}
} // namespace kernel
} // namespace mindspore

@ -33,12 +33,14 @@ class TransposeCPUFwdKernel : public CPUKernel {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
private:
void CheckParameter() const;
template <typename T>
void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
private:
std::vector<size_t> shape_;
std::vector<int> axis_;
std::vector<size_t> input_shape_;
std::vector<size_t> output_shape_;
std::vector<int64_t> axes_;
TypeId dtype_{kTypeUnknown};
using TypeKernel =
std::function<void(TransposeCPUFwdKernel *, const std::vector<AddressPtr> &, const std::vector<AddressPtr> &)>;

Loading…
Cancel
Save