|
|
|
@ -18,6 +18,7 @@
|
|
|
|
|
#include <vector>
|
|
|
|
|
#include "session/anf_runtime_algorithm.h"
|
|
|
|
|
#include "ir/primitive.h"
|
|
|
|
|
#include "utils/context/ms_context.h"
|
|
|
|
|
#include "utils/utils.h"
|
|
|
|
|
#include "pipeline/static_analysis/abstract_value.h"
|
|
|
|
|
#include "pre_activate/common/helper.h"
|
|
|
|
@ -110,19 +111,24 @@ const AnfNodePtr ConstToAttrStridedSliceGradPass::Process(const FuncGraphPtr &gr
|
|
|
|
|
auto strided_slice_grad = node->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(strided_slice_grad);
|
|
|
|
|
|
|
|
|
|
if (!CheckAttrs(strided_slice_grad)) {
|
|
|
|
|
MS_LOG(INFO) << "Check strided_slice_grad's attrs failed, graph not changed";
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
auto ms_context = MsContext::GetInstance();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(ms_context);
|
|
|
|
|
|
|
|
|
|
ValuePtrList strides_values;
|
|
|
|
|
if (!GetStridesValues(strided_slice_grad, &strides_values)) {
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
if (ms_context->device_target() == kAscendDevice) {
|
|
|
|
|
if (!CheckAttrs(strided_slice_grad)) {
|
|
|
|
|
MS_LOG(INFO) << "Check strided_slice_grad's attrs failed, graph not changed";
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (!CheckValues(strides_values)) {
|
|
|
|
|
MS_LOG(INFO) << "Check strides' values failed, graph not changed";
|
|
|
|
|
return nullptr;
|
|
|
|
|
ValuePtrList strides_values;
|
|
|
|
|
if (!GetStridesValues(strided_slice_grad, &strides_values)) {
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (!CheckValues(strides_values)) {
|
|
|
|
|
MS_LOG(INFO) << "Check strides' values failed, graph not changed";
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ConstInputToAttr(strided_slice_grad, {1, 2, 3, 4});
|
|
|
|
|