|
|
|
@ -50,9 +50,22 @@ CNodePtr CreateReluV2(const FuncGraphPtr &graph, const CNodePtr &relu) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(new_node);
|
|
|
|
|
new_node->set_scope(relu->scope());
|
|
|
|
|
|
|
|
|
|
// ReluV2's 2rd output is mask whose data type is uint8 and value is 0 or 1, so shape is an empty vector
|
|
|
|
|
// ReluV2's 2rd output is mask whose data type is uint8
|
|
|
|
|
TypeId mask_dtype = kNumberTypeUInt8;
|
|
|
|
|
std::vector<size_t> mask_shape;
|
|
|
|
|
std::vector<size_t> mask_shape = AnfAlgo::GetOutputInferShape(relu, 0);
|
|
|
|
|
if (mask_shape.size() != 4) {
|
|
|
|
|
MS_LOG(WARNING) << "relu's infer shape size not equal 4";
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
auto input_dtype = AnfAlgo::GetPrevNodeOutputInferDataType(relu, 0);
|
|
|
|
|
if (input_dtype == kNumberTypeUInt8 || input_dtype == kNumberTypeInt8) {
|
|
|
|
|
mask_shape[1] = (mask_shape[1] + 31) / 32;
|
|
|
|
|
mask_shape.push_back(4);
|
|
|
|
|
} else {
|
|
|
|
|
mask_shape[1] = (mask_shape[1] + 15) / 16;
|
|
|
|
|
mask_shape.push_back(2);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto types = {AnfAlgo::GetOutputInferDataType(relu, 0), mask_dtype};
|
|
|
|
|
auto shapes = {AnfAlgo::GetOutputInferShape(relu, 0), mask_shape};
|
|
|
|
|
AnfAlgo::SetOutputInferTypeAndShape(types, shapes, new_node.get());
|
|
|
|
@ -91,6 +104,9 @@ const AnfNodePtr DereluFusion::Process(const FuncGraphPtr &graph, const AnfNodeP
|
|
|
|
|
MS_EXCEPTION_IF_NULL(relu);
|
|
|
|
|
|
|
|
|
|
auto relu_v2 = CreateReluV2(graph, relu);
|
|
|
|
|
if (relu_v2 == nullptr) {
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
std::vector<AnfNodePtr> relu_v2_node_outputs;
|
|
|
|
|
CreateMultipleOutputsOfAnfNode(graph, relu_v2, kReluV2OutputNum, &relu_v2_node_outputs);
|
|
|
|
|
|
|
|
|
|