|  |  |  | @ -20,16 +20,13 @@ | 
			
		
	
		
			
				
					|  |  |  |  | #include <string> | 
			
		
	
		
			
				
					|  |  |  |  | #include <memory> | 
			
		
	
		
			
				
					|  |  |  |  | #include <algorithm> | 
			
		
	
		
			
				
					|  |  |  |  | #include <list> | 
			
		
	
		
			
				
					|  |  |  |  | #include <utility> | 
			
		
	
		
			
				
					|  |  |  |  | #include <cfloat> | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  | #include "abstract/abstract_value.h" | 
			
		
	
		
			
				
					|  |  |  |  | #include "ir/value.h" | 
			
		
	
		
			
				
					|  |  |  |  | #include "ir/tensor.h" | 
			
		
	
		
			
				
					|  |  |  |  | #include "ir/param_info.h" | 
			
		
	
		
			
				
					|  |  |  |  | #include "utils/ms_context.h" | 
			
		
	
		
			
				
					|  |  |  |  | #include "utils/shape_utils.h" | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  | namespace mindspore { | 
			
		
	
		
			
				
					|  |  |  |  | bool ValueToBool(const ValuePtr &v, bool *value) { | 
			
		
	
	
		
			
				
					|  |  |  | @ -37,13 +34,13 @@ bool ValueToBool(const ValuePtr &v, bool *value) { | 
			
		
	
		
			
				
					|  |  |  |  |   if (v->isa<BoolImm>()) { | 
			
		
	
		
			
				
					|  |  |  |  |     *value = v->cast<BoolImmPtr>()->value(); | 
			
		
	
		
			
				
					|  |  |  |  |   } else if (v->isa<Int32Imm>()) { | 
			
		
	
		
			
				
					|  |  |  |  |     *value = v->cast<Int32ImmPtr>()->value() == 0 ? false : true; | 
			
		
	
		
			
				
					|  |  |  |  |     *value = v->cast<Int32ImmPtr>()->value() != 0; | 
			
		
	
		
			
				
					|  |  |  |  |   } else if (v->isa<UInt32Imm>()) { | 
			
		
	
		
			
				
					|  |  |  |  |     *value = v->cast<UInt32ImmPtr>()->value() == 0 ? false : true; | 
			
		
	
		
			
				
					|  |  |  |  |     *value = v->cast<UInt32ImmPtr>()->value() != 0; | 
			
		
	
		
			
				
					|  |  |  |  |   } else if (v->isa<FP32Imm>()) { | 
			
		
	
		
			
				
					|  |  |  |  |     *value = v->cast<FP32ImmPtr>()->value() == 0 ? false : true; | 
			
		
	
		
			
				
					|  |  |  |  |     *value = v->cast<FP32ImmPtr>()->value() != 0; | 
			
		
	
		
			
				
					|  |  |  |  |   } else if (v->isa<FP64Imm>()) { | 
			
		
	
		
			
				
					|  |  |  |  |     *value = v->cast<FP64ImmPtr>()->value() == 0 ? false : true; | 
			
		
	
		
			
				
					|  |  |  |  |     *value = v->cast<FP64ImmPtr>()->value() != 0; | 
			
		
	
		
			
				
					|  |  |  |  |   } else if (v->isa<tensor::Tensor>()) { | 
			
		
	
		
			
				
					|  |  |  |  |     auto tensor = v->cast<tensor::TensorPtr>(); | 
			
		
	
		
			
				
					|  |  |  |  |     MS_EXCEPTION_IF_NULL(tensor); | 
			
		
	
	
		
			
				
					|  |  |  | @ -65,11 +62,11 @@ bool BaseRefToInt(const ValuePtr &v, int64_t *value) { | 
			
		
	
		
			
				
					|  |  |  |  |     auto tensor = v->cast<tensor::TensorPtr>(); | 
			
		
	
		
			
				
					|  |  |  |  |     (void)tensor->data_sync(); | 
			
		
	
		
			
				
					|  |  |  |  |     if (tensor->Dtype()->ToString() == "Int32") { | 
			
		
	
		
			
				
					|  |  |  |  |       int32_t *tensor_data = static_cast<int32_t *>(tensor->data_c()); | 
			
		
	
		
			
				
					|  |  |  |  |       auto *tensor_data = static_cast<int32_t *>(tensor->data_c()); | 
			
		
	
		
			
				
					|  |  |  |  |       auto vb = tensor_data[0]; | 
			
		
	
		
			
				
					|  |  |  |  |       *value = static_cast<int64_t>(vb); | 
			
		
	
		
			
				
					|  |  |  |  |     } else if (tensor->Dtype()->ToString() == "Int64") { | 
			
		
	
		
			
				
					|  |  |  |  |       int64_t *tensor_data = static_cast<int64_t *>(tensor->data_c()); | 
			
		
	
		
			
				
					|  |  |  |  |       auto *tensor_data = static_cast<int64_t *>(tensor->data_c()); | 
			
		
	
		
			
				
					|  |  |  |  |       auto vb = tensor_data[0]; | 
			
		
	
		
			
				
					|  |  |  |  |       *value = vb; | 
			
		
	
		
			
				
					|  |  |  |  |     } else { | 
			
		
	
	
		
			
				
					|  |  |  | @ -86,39 +83,19 @@ bool BaseRefToBool(const BaseRef &v, bool *value) { | 
			
		
	
		
			
				
					|  |  |  |  |     return ValueToBool(utils::cast<ValuePtr>(v), value); | 
			
		
	
		
			
				
					|  |  |  |  |   } else if (utils::isa<bool>(v)) { | 
			
		
	
		
			
				
					|  |  |  |  |     auto vb = utils::cast<bool>(v); | 
			
		
	
		
			
				
					|  |  |  |  |     if (vb == true) { | 
			
		
	
		
			
				
					|  |  |  |  |       *value = true; | 
			
		
	
		
			
				
					|  |  |  |  |     } else { | 
			
		
	
		
			
				
					|  |  |  |  |       *value = false; | 
			
		
	
		
			
				
					|  |  |  |  |     } | 
			
		
	
		
			
				
					|  |  |  |  |     *value = vb; | 
			
		
	
		
			
				
					|  |  |  |  |   } else if (utils::isa<int>(v)) { | 
			
		
	
		
			
				
					|  |  |  |  |     auto vb = utils::cast<int>(v); | 
			
		
	
		
			
				
					|  |  |  |  |     if (vb == 0) { | 
			
		
	
		
			
				
					|  |  |  |  |       *value = false; | 
			
		
	
		
			
				
					|  |  |  |  |     } else { | 
			
		
	
		
			
				
					|  |  |  |  |       *value = true; | 
			
		
	
		
			
				
					|  |  |  |  |     } | 
			
		
	
		
			
				
					|  |  |  |  |     *value = vb != 0; | 
			
		
	
		
			
				
					|  |  |  |  |   } else if (utils::isa<unsigned int>(v)) { | 
			
		
	
		
			
				
					|  |  |  |  |     auto vb = utils::cast<unsigned int>(v); | 
			
		
	
		
			
				
					|  |  |  |  |     if (vb == 0) { | 
			
		
	
		
			
				
					|  |  |  |  |       *value = false; | 
			
		
	
		
			
				
					|  |  |  |  |     } else { | 
			
		
	
		
			
				
					|  |  |  |  |       *value = true; | 
			
		
	
		
			
				
					|  |  |  |  |     } | 
			
		
	
		
			
				
					|  |  |  |  |     *value = vb != 0; | 
			
		
	
		
			
				
					|  |  |  |  |   } else if (utils::isa<float>(v)) { | 
			
		
	
		
			
				
					|  |  |  |  |     auto vb = utils::cast<float>(v); | 
			
		
	
		
			
				
					|  |  |  |  |     if (vb >= -FLT_EPSILON && vb <= FLT_EPSILON) { | 
			
		
	
		
			
				
					|  |  |  |  |       *value = false; | 
			
		
	
		
			
				
					|  |  |  |  |     } else { | 
			
		
	
		
			
				
					|  |  |  |  |       *value = true; | 
			
		
	
		
			
				
					|  |  |  |  |     } | 
			
		
	
		
			
				
					|  |  |  |  |     *value = !(vb >= -FLT_EPSILON && vb <= FLT_EPSILON); | 
			
		
	
		
			
				
					|  |  |  |  |   } else if (utils::isa<double>(v)) { | 
			
		
	
		
			
				
					|  |  |  |  |     auto vb = utils::cast<double>(v); | 
			
		
	
		
			
				
					|  |  |  |  |     if (vb >= -DBL_EPSILON && vb <= DBL_EPSILON) { | 
			
		
	
		
			
				
					|  |  |  |  |       *value = false; | 
			
		
	
		
			
				
					|  |  |  |  |     } else { | 
			
		
	
		
			
				
					|  |  |  |  |       *value = true; | 
			
		
	
		
			
				
					|  |  |  |  |     } | 
			
		
	
		
			
				
					|  |  |  |  |     *value = !(vb >= -DBL_EPSILON && vb <= DBL_EPSILON); | 
			
		
	
		
			
				
					|  |  |  |  |   } else { | 
			
		
	
		
			
				
					|  |  |  |  |     MS_LOG(DEBUG) << "value is not supported to cast to be bool"; | 
			
		
	
		
			
				
					|  |  |  |  |     return false; | 
			
		
	
	
		
			
				
					|  |  |  | @ -187,13 +164,13 @@ bool SameNode(const AnfNodePtr &node1, const AnfNodePtr &node2, FuncGraphPairMap | 
			
		
	
		
			
				
					|  |  |  |  |   return SameNodeShallow(node1, node2, equiv_func_graph, equiv_node); | 
			
		
	
		
			
				
					|  |  |  |  | } | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  | bool SameSubgraph(AnfNodePtr root1, AnfNodePtr root2, FuncGraphPairMapEquiv *equiv_func_graph, | 
			
		
	
		
			
				
					|  |  |  |  | bool SameSubgraph(const AnfNodePtr &root1, const AnfNodePtr &root2, FuncGraphPairMapEquiv *equiv_func_graph, | 
			
		
	
		
			
				
					|  |  |  |  |                   NodeMapEquiv *const equiv_node) { | 
			
		
	
		
			
				
					|  |  |  |  |   std::unordered_set<AnfNodePtr> done; | 
			
		
	
		
			
				
					|  |  |  |  |   std::stack<std::pair<AnfNodePtr, AnfNodePtr>> todo; | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  |   todo.push(std::make_pair(root1, root2)); | 
			
		
	
		
			
				
					|  |  |  |  |   while (todo.size() > 0) { | 
			
		
	
		
			
				
					|  |  |  |  |   while (!todo.empty()) { | 
			
		
	
		
			
				
					|  |  |  |  |     AnfNodePtr node1 = todo.top().first; | 
			
		
	
		
			
				
					|  |  |  |  |     if (done.count(node1) > 0) { | 
			
		
	
		
			
				
					|  |  |  |  |       todo.pop(); | 
			
		
	
	
		
			
				
					|  |  |  | @ -231,7 +208,7 @@ bool SameSubgraph(AnfNodePtr root1, AnfNodePtr root2, FuncGraphPairMapEquiv *equ | 
			
		
	
		
			
				
					|  |  |  |  | } | 
			
		
	
		
			
				
					|  |  |  |  | }  // namespace
 | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  | bool Isomorphic(FuncGraphPtr fg1, FuncGraphPtr fg2, FuncGraphPairMapEquiv *equiv_func_graph, | 
			
		
	
		
			
				
					|  |  |  |  | bool Isomorphic(const FuncGraphPtr &fg1, const FuncGraphPtr &fg2, FuncGraphPairMapEquiv *equiv_func_graph, | 
			
		
	
		
			
				
					|  |  |  |  |                 NodeMapEquiv *const equiv_node) { | 
			
		
	
		
			
				
					|  |  |  |  |   auto fg1_fg2 = std::make_pair(fg1, fg2); | 
			
		
	
		
			
				
					|  |  |  |  |   if (equiv_func_graph == nullptr) { | 
			
		
	
	
		
			
				
					|  |  |  | @ -267,23 +244,35 @@ tensor::TensorPtr ScalarToTensor(const ScalarPtr &scalar) { | 
			
		
	
		
			
				
					|  |  |  |  |   if (scalar == nullptr) { | 
			
		
	
		
			
				
					|  |  |  |  |     MS_EXCEPTION(ArgumentError) << "Nullptr Error!"; | 
			
		
	
		
			
				
					|  |  |  |  |   } | 
			
		
	
		
			
				
					|  |  |  |  |   tensor::TensorPtr tensor = nullptr; | 
			
		
	
		
			
				
					|  |  |  |  |   if (scalar->isa<FloatImm>()) { | 
			
		
	
		
			
				
					|  |  |  |  |     tensor = std::make_shared<tensor::Tensor>(static_cast<double>(GetValue<float>(scalar)), kFloat32); | 
			
		
	
		
			
				
					|  |  |  |  |   } else if (scalar->isa<Int32Imm>()) { | 
			
		
	
		
			
				
					|  |  |  |  |     tensor = std::make_shared<tensor::Tensor>(static_cast<int64_t>(GetValue<int>(scalar)), kInt32); | 
			
		
	
		
			
				
					|  |  |  |  |   } else if (scalar->isa<Int64Imm>()) { | 
			
		
	
		
			
				
					|  |  |  |  |     tensor = std::make_shared<tensor::Tensor>(GetValue<int64_t>(scalar), kInt64); | 
			
		
	
		
			
				
					|  |  |  |  |   } else if (scalar->isa<BoolImm>()) { | 
			
		
	
		
			
				
					|  |  |  |  |     const int64_t bool_value = GetValue<bool>(scalar) ? 1 : 0; | 
			
		
	
		
			
				
					|  |  |  |  |     tensor = std::make_shared<tensor::Tensor>(bool_value, kBool); | 
			
		
	
		
			
				
					|  |  |  |  |   } else { | 
			
		
	
		
			
				
					|  |  |  |  |     auto type = scalar->type(); | 
			
		
	
		
			
				
					|  |  |  |  |     auto type_str = (type == nullptr) ? "nullptr" : type->ToString(); | 
			
		
	
		
			
				
					|  |  |  |  |     MS_LOG(EXCEPTION) << "Invalid scalar type: " << type_str; | 
			
		
	
		
			
				
					|  |  |  |  |   TypePtr data_type = scalar->type(); | 
			
		
	
		
			
				
					|  |  |  |  |   MS_EXCEPTION_IF_NULL(data_type); | 
			
		
	
		
			
				
					|  |  |  |  |   TypeId type_id = data_type->type_id(); | 
			
		
	
		
			
				
					|  |  |  |  |   switch (type_id) { | 
			
		
	
		
			
				
					|  |  |  |  |     case kNumberTypeBool: | 
			
		
	
		
			
				
					|  |  |  |  |       return std::make_shared<tensor::Tensor>(GetValue<bool>(scalar), data_type); | 
			
		
	
		
			
				
					|  |  |  |  |     case kNumberTypeInt8: | 
			
		
	
		
			
				
					|  |  |  |  |       return std::make_shared<tensor::Tensor>(static_cast<int64_t>(GetValue<int8_t>(scalar)), data_type); | 
			
		
	
		
			
				
					|  |  |  |  |     case kNumberTypeInt16: | 
			
		
	
		
			
				
					|  |  |  |  |       return std::make_shared<tensor::Tensor>(static_cast<int64_t>(GetValue<int16_t>(scalar)), data_type); | 
			
		
	
		
			
				
					|  |  |  |  |     case kNumberTypeInt32: | 
			
		
	
		
			
				
					|  |  |  |  |       return std::make_shared<tensor::Tensor>(static_cast<int64_t>(GetValue<int32_t>(scalar)), data_type); | 
			
		
	
		
			
				
					|  |  |  |  |     case kNumberTypeInt64: | 
			
		
	
		
			
				
					|  |  |  |  |       return std::make_shared<tensor::Tensor>(GetValue<int64_t>(scalar), data_type); | 
			
		
	
		
			
				
					|  |  |  |  |     case kNumberTypeUInt8: | 
			
		
	
		
			
				
					|  |  |  |  |       return std::make_shared<tensor::Tensor>(static_cast<uint64_t>(GetValue<uint8_t>(scalar)), data_type); | 
			
		
	
		
			
				
					|  |  |  |  |     case kNumberTypeUInt16: | 
			
		
	
		
			
				
					|  |  |  |  |       return std::make_shared<tensor::Tensor>(static_cast<uint64_t>(GetValue<uint16_t>(scalar)), data_type); | 
			
		
	
		
			
				
					|  |  |  |  |     case kNumberTypeUInt32: | 
			
		
	
		
			
				
					|  |  |  |  |       return std::make_shared<tensor::Tensor>(static_cast<uint64_t>(GetValue<uint32_t>(scalar)), data_type); | 
			
		
	
		
			
				
					|  |  |  |  |     case kNumberTypeUInt64: | 
			
		
	
		
			
				
					|  |  |  |  |       return std::make_shared<tensor::Tensor>(GetValue<uint64_t>(scalar), data_type); | 
			
		
	
		
			
				
					|  |  |  |  |     case kNumberTypeFloat32: | 
			
		
	
		
			
				
					|  |  |  |  |       return std::make_shared<tensor::Tensor>(GetValue<float>(scalar), data_type); | 
			
		
	
		
			
				
					|  |  |  |  |     case kNumberTypeFloat64: | 
			
		
	
		
			
				
					|  |  |  |  |       return std::make_shared<tensor::Tensor>(GetValue<double>(scalar), data_type); | 
			
		
	
		
			
				
					|  |  |  |  |     default: | 
			
		
	
		
			
				
					|  |  |  |  |       MS_LOG(EXCEPTION) << "When convert scalar to tensor, the scalar type: " << data_type << "is valid."; | 
			
		
	
		
			
				
					|  |  |  |  |   } | 
			
		
	
		
			
				
					|  |  |  |  |   MS_EXCEPTION_IF_NULL(tensor); | 
			
		
	
		
			
				
					|  |  |  |  |   return tensor; | 
			
		
	
		
			
				
					|  |  |  |  | } | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  | void TensorValueToTensor(const ValuePtr &value, std::vector<tensor::TensorPtr> *tensors) { | 
			
		
	
	
		
			
				
					|  |  |  | @ -301,7 +290,7 @@ void TensorValueToTensor(const ValuePtr &value, std::vector<tensor::TensorPtr> * | 
			
		
	
		
			
				
					|  |  |  |  |       } | 
			
		
	
		
			
				
					|  |  |  |  |     } | 
			
		
	
		
			
				
					|  |  |  |  |   } else if (value->isa<tensor::Tensor>()) { | 
			
		
	
		
			
				
					|  |  |  |  |     tensor::TensorPtr tensor = value->cast<tensor::TensorPtr>(); | 
			
		
	
		
			
				
					|  |  |  |  |     auto tensor = value->cast<tensor::TensorPtr>(); | 
			
		
	
		
			
				
					|  |  |  |  |     MS_EXCEPTION_IF_NULL(tensor); | 
			
		
	
		
			
				
					|  |  |  |  |     tensors->push_back(tensor); | 
			
		
	
		
			
				
					|  |  |  |  |   } | 
			
		
	
	
		
			
				
					|  |  |  | 
 |