@@ -32,7 +32,7 @@ CNodePtr CreateReduceMin(const FuncGraphPtr &graph, const AnfNodePtr &input, con
   return reduce_min;
 }
 
-bool NeedOptmize(const TypeId &dtype, const std::vector<size_t> &shape, const std::vector<int> &axis) {
+bool NeedOptimize(const TypeId &dtype, const std::vector<size_t> &shape, const std::vector<int> &axis) {
   if (dtype != kNumberTypeFloat32) {
     MS_LOG(INFO) << "ReduceMin's input Dtype is not float32, no need optimize!";
     return false;
@@ -84,7 +84,7 @@ std::vector<size_t> GetInferShape(const std::vector<size_t> &shape, const std::v
   for (size_t item = 0; item < shape.size(); ++item) {
     if (axis_first.end() != std::find(axis_first.begin(), axis_first.end(), item)) {
       if (keep_dims) {
-        // If keep_dims is true, curretn dimesion set to 1
+        // If keep_dims is true, current dimension set to 1
         shape_first.push_back(1);
       }
     } else {
@@ -110,28 +110,31 @@ const AnfNodePtr ReduceMinFission::Process(const FuncGraphPtr &graph, const AnfN
   CheckCNodeInputSize(cnode, 2);
   auto shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, 0);
   auto dtype = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, 0);
-  if (!AnfAlgo::HasNodeAttr(kAttrAxis, cnode)) {
-    MS_LOG(INFO) << "ReduceMin has no axis, no need optimize!";
+  auto prim = AnfAlgo::GetCNodePrimitive(cnode);
+  MS_EXCEPTION_IF_NULL(prim);
+  if (!prim->HasAttr(kAttrAxis) || !prim->HasAttr(kAttrKeepDims)) {
+    MS_LOG(INFO) << "ReduceMin has no axis or keep_dims, no need optimize!";
     return nullptr;
   }
-  auto axis = AnfAlgo::GetNodeAttr<std::vector<int>>(cnode, kAttrAxis);
-  if (!AnfAlgo::HasNodeAttr(kAttrKeepDims, cnode)) {
-    MS_LOG(INFO) << "ReduceMin has no keep_dims, no need optimize!";
+  auto axis_value = prim->GetAttr(kAttrAxis);
+  MS_EXCEPTION_IF_NULL(axis_value);
+  if (!axis_value->isa<ValueSequeue>()) {
     return nullptr;
   }
+  auto axis = AnfAlgo::GetNodeAttr<std::vector<int>>(cnode, kAttrAxis);
   auto keep_dims = AnfAlgo::GetNodeAttr<bool>(cnode, kAttrKeepDims);
 
-  if (!NeedOptmize(dtype, shape, axis)) {
+  if (!NeedOptimize(dtype, shape, axis)) {
     MS_LOG(INFO) << "No need optimize for this ReduceMin. " << cnode->DebugString();
     return nullptr;
   }
 
   // Create reduce_min1
   CNodePtr reduce_min1 = CreateReduceMin(graph, cnode->input(1), cnode);
-  std::vector<int> axis_fisrt = CalFirstAxis(shape, axis);
-  std::vector<size_t> shape_first = GetInferShape(shape, axis_fisrt, keep_dims);
+  std::vector<int> axis_first = CalFirstAxis(shape, axis);
+  std::vector<size_t> shape_first = GetInferShape(shape, axis_first, keep_dims);
   AnfAlgo::SetOutputInferTypeAndShape({dtype}, {shape_first}, reduce_min1.get());
-  AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(axis_fisrt), reduce_min1);
+  AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(axis_first), reduce_min1);
 
   // Create reduce_min2
   CNodePtr reduce_min2 = CreateReduceMin(graph, reduce_min1, cnode);