|
|
|
@ -149,12 +149,22 @@ class StridedSliceGradGpuKernel : public GpuKernel {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto shrink_axis_mask_str = static_cast<int64_t>(GetAttr<int64_t>(kernel_node, "shrink_axis_mask"));
|
|
|
|
|
auto shrink_axis_mask = Dec2Bin(shrink_axis_mask_str);
|
|
|
|
|
for (size_t l = 0; l < shrink_axis_mask.size(); l++) {
|
|
|
|
|
if (shrink_axis_mask[l]) {
|
|
|
|
|
end_[l] = end_[l] > begin_[l] ? begin_[l] + 1 : begin_[l] - 1;
|
|
|
|
|
strides_[l] = end_[l] > begin_[l] ? 1 : -1;
|
|
|
|
|
auto new_axis_mask_int = static_cast<int64_t>(GetAttr<int64_t>(kernel_node, "new_axis_mask"));
|
|
|
|
|
auto new_axis_mask = Dec2Bin(new_axis_mask_int);
|
|
|
|
|
for (size_t l = 0; l < new_axis_mask.size(); l++) {
|
|
|
|
|
if (new_axis_mask[l]) {
|
|
|
|
|
begin_[l] = 0;
|
|
|
|
|
end_[l] = input_shape_[l];
|
|
|
|
|
strides_[l] = 1;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto shrink_axis_mask_int = static_cast<int64_t>(GetAttr<int64_t>(kernel_node, "shrink_axis_mask"));
|
|
|
|
|
auto shrink_axis_mask = Dec2Bin(shrink_axis_mask_int);
|
|
|
|
|
for (size_t m = 0; m < shrink_axis_mask.size(); m++) {
|
|
|
|
|
if (shrink_axis_mask[m]) {
|
|
|
|
|
end_[m] = end_[m] > begin_[m] ? begin_[m] + 1 : begin_[m] - 1;
|
|
|
|
|
strides_[m] = end_[m] > begin_[m] ? 1 : -1;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|