!8806 GPU Stridedslice support `new_axis_mask`

From: @wilfchen
Reviewed-by: @limingqi107,@kisnwang
Signed-off-by: @kisnwang
pull/8806/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit ecdaeebd43

@ -143,12 +143,22 @@ class StridedSliceGpuKernel : public GpuKernel {
}
}
auto shrink_axis_mask_str = static_cast<int64_t>(GetAttr<int64_t>(kernel_node, "shrink_axis_mask"));
auto shrink_axis_mask = Dec2Bin(shrink_axis_mask_str);
for (size_t l = 0; l < shrink_axis_mask.size(); l++) {
if (shrink_axis_mask[l]) {
end_[l] = end_[l] > begin_[l] ? begin_[l] + 1 : begin_[l] - 1;
strides_[l] = end_[l] > begin_[l] ? 1 : -1;
auto new_axis_mask_int = static_cast<int64_t>(GetAttr<int64_t>(kernel_node, "new_axis_mask"));
auto new_axis_mask = Dec2Bin(new_axis_mask_int);
for (size_t l = 0; l < new_axis_mask.size(); l++) {
if (new_axis_mask[l]) {
begin_[l] = 0;
end_[l] = input_shape_[l];
strides_[l] = 1;
}
}
auto shrink_axis_mask_int = static_cast<int64_t>(GetAttr<int64_t>(kernel_node, "shrink_axis_mask"));
auto shrink_axis_mask = Dec2Bin(shrink_axis_mask_int);
for (size_t m = 0; m < shrink_axis_mask.size(); m++) {
if (shrink_axis_mask[m]) {
end_[m] = end_[m] > begin_[m] ? begin_[m] + 1 : begin_[m] - 1;
strides_[m] = end_[m] > begin_[m] ? 1 : -1;
}
}
}

@ -149,12 +149,22 @@ class StridedSliceGradGpuKernel : public GpuKernel {
}
}
auto shrink_axis_mask_str = static_cast<int64_t>(GetAttr<int64_t>(kernel_node, "shrink_axis_mask"));
auto shrink_axis_mask = Dec2Bin(shrink_axis_mask_str);
for (size_t l = 0; l < shrink_axis_mask.size(); l++) {
if (shrink_axis_mask[l]) {
end_[l] = end_[l] > begin_[l] ? begin_[l] + 1 : begin_[l] - 1;
strides_[l] = end_[l] > begin_[l] ? 1 : -1;
auto new_axis_mask_int = static_cast<int64_t>(GetAttr<int64_t>(kernel_node, "new_axis_mask"));
auto new_axis_mask = Dec2Bin(new_axis_mask_int);
for (size_t l = 0; l < new_axis_mask.size(); l++) {
if (new_axis_mask[l]) {
begin_[l] = 0;
end_[l] = input_shape_[l];
strides_[l] = 1;
}
}
auto shrink_axis_mask_int = static_cast<int64_t>(GetAttr<int64_t>(kernel_node, "shrink_axis_mask"));
auto shrink_axis_mask = Dec2Bin(shrink_axis_mask_int);
for (size_t m = 0; m < shrink_axis_mask.size(); m++) {
if (shrink_axis_mask[m]) {
end_[m] = end_[m] > begin_[m] ? begin_[m] + 1 : begin_[m] - 1;
strides_[m] = end_[m] > begin_[m] ? 1 : -1;
}
}
}

Loading…
Cancel
Save