|
|
@ -54,7 +54,7 @@ CNodePtr CreateReluV2(const FuncGraphPtr &graph, const CNodePtr &relu) {
|
|
|
|
TypeId mask_dtype = kNumberTypeUInt8;
|
|
|
|
TypeId mask_dtype = kNumberTypeUInt8;
|
|
|
|
std::vector<size_t> mask_shape = AnfAlgo::GetOutputInferShape(relu, 0);
|
|
|
|
std::vector<size_t> mask_shape = AnfAlgo::GetOutputInferShape(relu, 0);
|
|
|
|
if (mask_shape.size() != 4) {
|
|
|
|
if (mask_shape.size() != 4) {
|
|
|
|
MS_LOG(WARNING) << "relu's infer shape size not equal 4";
|
|
|
|
MS_LOG(DEBUG) << "relu's infer shape size not equal 4";
|
|
|
|
return nullptr;
|
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
auto input_dtype = AnfAlgo::GetPrevNodeOutputInferDataType(relu, 0);
|
|
|
|
auto input_dtype = AnfAlgo::GetPrevNodeOutputInferDataType(relu, 0);
|
|
|
|