!7781 fix cpu strideslice grad

Merge pull request !7781 from baihuawei/strideslice
pull/7781/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit c8c7d7400c

@ -22,15 +22,7 @@ namespace kernel {
void SliceGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
CheckParam(kernel_node);
output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0);
input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
begin_ = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, BEGIN);
for (size_t i = 0; i < begin_.size(); i++) {
if (begin_[i] < 0) {
begin_[i] = begin_[i] + output_shape_[i];
}
}
auto prim = AnfAlgo::GetCNodePrimitive(kernel_node);
MS_EXCEPTION_IF_NULL(prim);
auto strides = prim->GetAttr(STRIDES);
@ -40,36 +32,20 @@ void SliceGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
if (strides_.size() != end_.size() || strides_.size() != output_shape_.size()) {
MS_LOG(EXCEPTION) << "stride|end|input size must be equal";
}
for (size_t i = 0; i < strides_.size(); ++i) {
if (strides_[i] < 0) {
strides_[i] = (strides_[i] + output_shape_[i]) > 0 ? (strides_[i] + output_shape_[i]) : 0;
}
if (end_[i] < 0) {
end_[i] = (end_[i] + output_shape_[i]) > 0 ? (end_[i] + output_shape_[i]) : 0;
}
}
FormatArgs(true);
} else {
auto sizes = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, SIZE);
if (sizes.size() != output_shape_.size() || begin_.size() != output_shape_.size()) {
size_ = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, SIZE);
if (size_.size() != output_shape_.size() || begin_.size() != output_shape_.size()) {
MS_LOG(EXCEPTION) << "begin|size|input size must be equal";
}
for (size_t i = 0; i < sizes.size(); ++i) {
if (sizes[i] < 0) {
sizes[i] = (sizes[i] + output_shape_[i]) > 0 ? (sizes[i] + output_shape_[i]) : 0;
}
strides_.emplace_back(1);
end_.emplace_back(begin_[i] + sizes[i]);
}
FormatArgs(false);
}
ExpandAllMemberDims();
CPUKernelUtils::GetElementNumEveryDim(input_shape_, &input_element_num_);
CPUKernelUtils::GetElementNumEveryDim(output_shape_, &output_element_num_);
}
void SliceGradCPUKernel::ExpandAllMemberDims() {
CPUKernelUtils::ExpandDimsTo4(&input_shape_);
auto output_len = output_shape_.size();
if (output_len < 4) {
for (size_t i = 0; i < 4 - output_len; ++i) {
@ -79,6 +55,15 @@ void SliceGradCPUKernel::ExpandAllMemberDims() {
end_.insert(end_.begin(), 1);
}
}
for (size_t i = 0; i < 4; ++i) {
if (SignOfStride(i)) {
int ax = (end_[i] - begin_[i]) * SignOfStride(i);
if (ax < 0) {
ax = 0;
}
input_shape_.push_back(IntToSize(ax));
}
}
}
bool SliceGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
@ -92,40 +77,39 @@ bool SliceGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
MS_LOG(ERROR) << "output buff memset fail. ret:" << ret;
return false;
}
bool can_copy_memory[3] = {CanCopyMemoryOnAxis(0), CanCopyMemoryOnAxis(1), CanCopyMemoryOnAxis(2)};
int stride_signs[4] = {SignOfStride(0), SignOfStride(1), SignOfStride(2), SignOfStride(3)};
size_t out_start_offset[3] = {begin_[0] * output_element_num_[0], begin_[1] * output_element_num_[1],
begin_[2] * output_element_num_[2]};
size_t out_step_size[3] = {strides_[0] * output_element_num_[0], strides_[1] * output_element_num_[1],
strides_[2] * output_element_num_[2]};
auto in_n_offset = 0;
auto out_n_offset = out_start_offset[0];
for (int i = begin_[0]; i < end_[0];
for (int i = begin_[0]; stride_signs[0] * i < stride_signs[0] * end_[0];
i += strides_[0], in_n_offset += input_element_num_[0], out_n_offset += out_step_size[0]) {
if (can_copy_memory[0]) {
CopyDataToOutput(inputs, in_n_offset, outputs, out_n_offset, input_element_num_[0]);
CopyDataToOutput(inputs, in_n_offset, outputs, out_n_offset, input_element_num_[0], 0);
continue;
}
auto in_c_offset = 0;
auto out_c_offset = out_start_offset[1];
for (int j = begin_[1]; j < end_[1];
for (int j = begin_[1]; stride_signs[1] * j < stride_signs[1] * end_[1];
j += strides_[1], in_c_offset += input_element_num_[1], out_c_offset += out_step_size[1]) {
if (can_copy_memory[1]) {
CopyDataToOutput(inputs, in_n_offset + in_c_offset, outputs, out_n_offset + out_c_offset,
input_element_num_[1]);
CopyDataToOutput(inputs, in_n_offset + in_c_offset, outputs, out_n_offset + out_c_offset, input_element_num_[1],
1);
continue;
}
auto in_h_offset = 0;
auto out_h_offset = out_start_offset[2];
for (int k = begin_[2]; k < end_[2];
for (int k = begin_[2]; stride_signs[2] * k < stride_signs[2] * end_[2];
k += strides_[2], in_h_offset += input_element_num_[2], out_h_offset += out_step_size[2]) {
if (can_copy_memory[2]) {
CopyDataToOutput(inputs, in_n_offset + in_c_offset + in_h_offset, outputs,
out_n_offset + out_c_offset + out_h_offset, input_element_num_[2]);
out_n_offset + out_c_offset + out_h_offset, input_element_num_[2], 2);
continue;
}
for (int m = begin_[3]; m < end_[3]; m += strides_[3]) {
for (int m = begin_[3]; stride_signs[3] * m < stride_signs[3] * end_[3]; m += strides_[3]) {
output_addr[out_n_offset + out_c_offset + out_h_offset + m] = *input_addr++;
}
}
@ -143,19 +127,26 @@ bool SliceGradCPUKernel::CanCopyMemoryOnAxis(size_t dim) const {
return true;
}
int SliceGradCPUKernel::SignOfStride(size_t axis) const {
if (strides_[axis] > 0) {
return 1;
}
return -1;
}
void SliceGradCPUKernel::CopyDataToOutput(const std::vector<kernel::AddressPtr> &inputs, size_t in_offset,
const std::vector<kernel::AddressPtr> &outputs, size_t out_offset,
size_t copy_num) const {
size_t copy_num, int id) const {
auto input_addr = reinterpret_cast<float *>(inputs[0]->addr);
auto in_buff_size = inputs[0]->size;
auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
auto out_buff_size = outputs[0]->size;
if ((in_offset + copy_num) * sizeof(float) > in_buff_size) {
MS_LOG(EXCEPTION) << "input memory out of bounds.";
MS_LOG(EXCEPTION) << id << "input memory out of bounds.";
}
if ((out_offset + copy_num) * sizeof(float) > out_buff_size) {
MS_LOG(EXCEPTION) << "output memory out of bounds.";
MS_LOG(EXCEPTION) << id << "output memory out of bounds.";
}
auto ret = memcpy_s(output_addr + out_offset, out_buff_size - out_offset * sizeof(float), input_addr + in_offset,
@ -165,6 +156,43 @@ void SliceGradCPUKernel::CopyDataToOutput(const std::vector<kernel::AddressPtr>
}
}
void SliceGradCPUKernel::FormatArgs(bool stride) {
if (stride) {
for (size_t i = 0; i < strides_.size(); ++i) {
if (strides_[i] == 0) {
MS_LOG(EXCEPTION) << "slice stride cannot be zero";
}
if (end_[i] == 0 && begin_[i] < 0) {
end_[i] = end_[i] + SizeToInt(output_shape_[i]);
}
if (end_[i] < 0) {
end_[i] = end_[i] + SizeToInt(output_shape_[i]) < 0 ? 0 : end_[i] + SizeToInt(output_shape_[i]);
}
if (end_[i] > SizeToInt(output_shape_[i])) {
end_[i] = SizeToInt(output_shape_[i]);
}
}
}
for (size_t i = 0; i < begin_.size(); i++) {
if (begin_[i] < 0) {
auto k = begin_[i] + SizeToInt(output_shape_[i]);
begin_[i] = k < 0 ? 0 : k;
}
if (begin_[i] > SizeToInt(output_shape_[i])) {
begin_[i] = SizeToInt(output_shape_[i]);
}
}
if (!stride) {
for (size_t i = 0; i < size_.size(); ++i) {
while (size_[i] < 0) {
size_[i] = size_[i] + SizeToInt(output_shape_[i]);
}
strides_.emplace_back(1);
end_.emplace_back(begin_[i] + size_[i]);
}
}
}
void SliceGradCPUKernel::CheckParam(const CNodePtr &kernel_node) const {
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 1) {

@ -35,12 +35,16 @@ class SliceGradCPUKernel : public CPUKernel {
private:
void ExpandAllMemberDims();
bool CanCopyMemoryOnAxis(size_t dim) const;
int SignOfStride(size_t axis) const;
void CopyDataToOutput(const std::vector<kernel::AddressPtr> &inputs, size_t in_offset,
const std::vector<kernel::AddressPtr> &outputs, size_t out_offset, size_t copy_num) const;
const std::vector<kernel::AddressPtr> &outputs, size_t out_offset, size_t copy_num,
int id) const;
void CheckParam(const CNodePtr &kernel_node) const;
void FormatArgs(bool stride);
std::vector<int> begin_;
std::vector<int> end_;
std::vector<int> strides_;
std::vector<int> size_;
std::vector<size_t> input_shape_;
std::vector<size_t> input_element_num_;
std::vector<size_t> output_shape_;

Loading…
Cancel
Save