|
|
|
@ -20,16 +20,13 @@
|
|
|
|
|
#include <string>
|
|
|
|
|
#include <memory>
|
|
|
|
|
#include <algorithm>
|
|
|
|
|
#include <list>
|
|
|
|
|
#include <utility>
|
|
|
|
|
#include <cfloat>
|
|
|
|
|
|
|
|
|
|
#include "abstract/abstract_value.h"
|
|
|
|
|
#include "ir/value.h"
|
|
|
|
|
#include "ir/tensor.h"
|
|
|
|
|
#include "ir/param_info.h"
|
|
|
|
|
#include "utils/ms_context.h"
|
|
|
|
|
#include "utils/shape_utils.h"
|
|
|
|
|
|
|
|
|
|
namespace mindspore {
|
|
|
|
|
bool ValueToBool(const ValuePtr &v, bool *value) {
|
|
|
|
@ -37,13 +34,13 @@ bool ValueToBool(const ValuePtr &v, bool *value) {
|
|
|
|
|
if (v->isa<BoolImm>()) {
|
|
|
|
|
*value = v->cast<BoolImmPtr>()->value();
|
|
|
|
|
} else if (v->isa<Int32Imm>()) {
|
|
|
|
|
*value = v->cast<Int32ImmPtr>()->value() == 0 ? false : true;
|
|
|
|
|
*value = v->cast<Int32ImmPtr>()->value() != 0;
|
|
|
|
|
} else if (v->isa<UInt32Imm>()) {
|
|
|
|
|
*value = v->cast<UInt32ImmPtr>()->value() == 0 ? false : true;
|
|
|
|
|
*value = v->cast<UInt32ImmPtr>()->value() != 0;
|
|
|
|
|
} else if (v->isa<FP32Imm>()) {
|
|
|
|
|
*value = v->cast<FP32ImmPtr>()->value() == 0 ? false : true;
|
|
|
|
|
*value = v->cast<FP32ImmPtr>()->value() != 0;
|
|
|
|
|
} else if (v->isa<FP64Imm>()) {
|
|
|
|
|
*value = v->cast<FP64ImmPtr>()->value() == 0 ? false : true;
|
|
|
|
|
*value = v->cast<FP64ImmPtr>()->value() != 0;
|
|
|
|
|
} else if (v->isa<tensor::Tensor>()) {
|
|
|
|
|
auto tensor = v->cast<tensor::TensorPtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(tensor);
|
|
|
|
@ -65,11 +62,11 @@ bool BaseRefToInt(const ValuePtr &v, int64_t *value) {
|
|
|
|
|
auto tensor = v->cast<tensor::TensorPtr>();
|
|
|
|
|
(void)tensor->data_sync();
|
|
|
|
|
if (tensor->Dtype()->ToString() == "Int32") {
|
|
|
|
|
int32_t *tensor_data = static_cast<int32_t *>(tensor->data_c());
|
|
|
|
|
auto *tensor_data = static_cast<int32_t *>(tensor->data_c());
|
|
|
|
|
auto vb = tensor_data[0];
|
|
|
|
|
*value = static_cast<int64_t>(vb);
|
|
|
|
|
} else if (tensor->Dtype()->ToString() == "Int64") {
|
|
|
|
|
int64_t *tensor_data = static_cast<int64_t *>(tensor->data_c());
|
|
|
|
|
auto *tensor_data = static_cast<int64_t *>(tensor->data_c());
|
|
|
|
|
auto vb = tensor_data[0];
|
|
|
|
|
*value = vb;
|
|
|
|
|
} else {
|
|
|
|
@ -86,39 +83,19 @@ bool BaseRefToBool(const BaseRef &v, bool *value) {
|
|
|
|
|
return ValueToBool(utils::cast<ValuePtr>(v), value);
|
|
|
|
|
} else if (utils::isa<bool>(v)) {
|
|
|
|
|
auto vb = utils::cast<bool>(v);
|
|
|
|
|
if (vb == true) {
|
|
|
|
|
*value = true;
|
|
|
|
|
} else {
|
|
|
|
|
*value = false;
|
|
|
|
|
}
|
|
|
|
|
*value = vb;
|
|
|
|
|
} else if (utils::isa<int>(v)) {
|
|
|
|
|
auto vb = utils::cast<int>(v);
|
|
|
|
|
if (vb == 0) {
|
|
|
|
|
*value = false;
|
|
|
|
|
} else {
|
|
|
|
|
*value = true;
|
|
|
|
|
}
|
|
|
|
|
*value = vb != 0;
|
|
|
|
|
} else if (utils::isa<unsigned int>(v)) {
|
|
|
|
|
auto vb = utils::cast<unsigned int>(v);
|
|
|
|
|
if (vb == 0) {
|
|
|
|
|
*value = false;
|
|
|
|
|
} else {
|
|
|
|
|
*value = true;
|
|
|
|
|
}
|
|
|
|
|
*value = vb != 0;
|
|
|
|
|
} else if (utils::isa<float>(v)) {
|
|
|
|
|
auto vb = utils::cast<float>(v);
|
|
|
|
|
if (vb >= -FLT_EPSILON && vb <= FLT_EPSILON) {
|
|
|
|
|
*value = false;
|
|
|
|
|
} else {
|
|
|
|
|
*value = true;
|
|
|
|
|
}
|
|
|
|
|
*value = !(vb >= -FLT_EPSILON && vb <= FLT_EPSILON);
|
|
|
|
|
} else if (utils::isa<double>(v)) {
|
|
|
|
|
auto vb = utils::cast<double>(v);
|
|
|
|
|
if (vb >= -DBL_EPSILON && vb <= DBL_EPSILON) {
|
|
|
|
|
*value = false;
|
|
|
|
|
} else {
|
|
|
|
|
*value = true;
|
|
|
|
|
}
|
|
|
|
|
*value = !(vb >= -DBL_EPSILON && vb <= DBL_EPSILON);
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(DEBUG) << "value is not supported to cast to be bool";
|
|
|
|
|
return false;
|
|
|
|
@ -187,13 +164,13 @@ bool SameNode(const AnfNodePtr &node1, const AnfNodePtr &node2, FuncGraphPairMap
|
|
|
|
|
return SameNodeShallow(node1, node2, equiv_func_graph, equiv_node);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool SameSubgraph(AnfNodePtr root1, AnfNodePtr root2, FuncGraphPairMapEquiv *equiv_func_graph,
|
|
|
|
|
bool SameSubgraph(const AnfNodePtr &root1, const AnfNodePtr &root2, FuncGraphPairMapEquiv *equiv_func_graph,
|
|
|
|
|
NodeMapEquiv *const equiv_node) {
|
|
|
|
|
std::unordered_set<AnfNodePtr> done;
|
|
|
|
|
std::stack<std::pair<AnfNodePtr, AnfNodePtr>> todo;
|
|
|
|
|
|
|
|
|
|
todo.push(std::make_pair(root1, root2));
|
|
|
|
|
while (todo.size() > 0) {
|
|
|
|
|
while (!todo.empty()) {
|
|
|
|
|
AnfNodePtr node1 = todo.top().first;
|
|
|
|
|
if (done.count(node1) > 0) {
|
|
|
|
|
todo.pop();
|
|
|
|
@ -231,7 +208,7 @@ bool SameSubgraph(AnfNodePtr root1, AnfNodePtr root2, FuncGraphPairMapEquiv *equ
|
|
|
|
|
}
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
bool Isomorphic(FuncGraphPtr fg1, FuncGraphPtr fg2, FuncGraphPairMapEquiv *equiv_func_graph,
|
|
|
|
|
bool Isomorphic(const FuncGraphPtr &fg1, const FuncGraphPtr &fg2, FuncGraphPairMapEquiv *equiv_func_graph,
|
|
|
|
|
NodeMapEquiv *const equiv_node) {
|
|
|
|
|
auto fg1_fg2 = std::make_pair(fg1, fg2);
|
|
|
|
|
if (equiv_func_graph == nullptr) {
|
|
|
|
@ -267,23 +244,35 @@ tensor::TensorPtr ScalarToTensor(const ScalarPtr &scalar) {
|
|
|
|
|
if (scalar == nullptr) {
|
|
|
|
|
MS_EXCEPTION(ArgumentError) << "Nullptr Error!";
|
|
|
|
|
}
|
|
|
|
|
tensor::TensorPtr tensor = nullptr;
|
|
|
|
|
if (scalar->isa<FloatImm>()) {
|
|
|
|
|
tensor = std::make_shared<tensor::Tensor>(static_cast<double>(GetValue<float>(scalar)), kFloat32);
|
|
|
|
|
} else if (scalar->isa<Int32Imm>()) {
|
|
|
|
|
tensor = std::make_shared<tensor::Tensor>(static_cast<int64_t>(GetValue<int>(scalar)), kInt32);
|
|
|
|
|
} else if (scalar->isa<Int64Imm>()) {
|
|
|
|
|
tensor = std::make_shared<tensor::Tensor>(GetValue<int64_t>(scalar), kInt64);
|
|
|
|
|
} else if (scalar->isa<BoolImm>()) {
|
|
|
|
|
const int64_t bool_value = GetValue<bool>(scalar) ? 1 : 0;
|
|
|
|
|
tensor = std::make_shared<tensor::Tensor>(bool_value, kBool);
|
|
|
|
|
} else {
|
|
|
|
|
auto type = scalar->type();
|
|
|
|
|
auto type_str = (type == nullptr) ? "nullptr" : type->ToString();
|
|
|
|
|
MS_LOG(EXCEPTION) << "Invalid scalar type: " << type_str;
|
|
|
|
|
TypePtr data_type = scalar->type();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(data_type);
|
|
|
|
|
TypeId type_id = data_type->type_id();
|
|
|
|
|
switch (type_id) {
|
|
|
|
|
case kNumberTypeBool:
|
|
|
|
|
return std::make_shared<tensor::Tensor>(GetValue<bool>(scalar), data_type);
|
|
|
|
|
case kNumberTypeInt8:
|
|
|
|
|
return std::make_shared<tensor::Tensor>(static_cast<int64_t>(GetValue<int8_t>(scalar)), data_type);
|
|
|
|
|
case kNumberTypeInt16:
|
|
|
|
|
return std::make_shared<tensor::Tensor>(static_cast<int64_t>(GetValue<int16_t>(scalar)), data_type);
|
|
|
|
|
case kNumberTypeInt32:
|
|
|
|
|
return std::make_shared<tensor::Tensor>(static_cast<int64_t>(GetValue<int32_t>(scalar)), data_type);
|
|
|
|
|
case kNumberTypeInt64:
|
|
|
|
|
return std::make_shared<tensor::Tensor>(GetValue<int64_t>(scalar), data_type);
|
|
|
|
|
case kNumberTypeUInt8:
|
|
|
|
|
return std::make_shared<tensor::Tensor>(static_cast<uint64_t>(GetValue<uint8_t>(scalar)), data_type);
|
|
|
|
|
case kNumberTypeUInt16:
|
|
|
|
|
return std::make_shared<tensor::Tensor>(static_cast<uint64_t>(GetValue<uint16_t>(scalar)), data_type);
|
|
|
|
|
case kNumberTypeUInt32:
|
|
|
|
|
return std::make_shared<tensor::Tensor>(static_cast<uint64_t>(GetValue<uint32_t>(scalar)), data_type);
|
|
|
|
|
case kNumberTypeUInt64:
|
|
|
|
|
return std::make_shared<tensor::Tensor>(GetValue<uint64_t>(scalar), data_type);
|
|
|
|
|
case kNumberTypeFloat32:
|
|
|
|
|
return std::make_shared<tensor::Tensor>(GetValue<float>(scalar), data_type);
|
|
|
|
|
case kNumberTypeFloat64:
|
|
|
|
|
return std::make_shared<tensor::Tensor>(GetValue<double>(scalar), data_type);
|
|
|
|
|
default:
|
|
|
|
|
MS_LOG(EXCEPTION) << "When convert scalar to tensor, the scalar type: " << data_type << "is valid.";
|
|
|
|
|
}
|
|
|
|
|
MS_EXCEPTION_IF_NULL(tensor);
|
|
|
|
|
return tensor;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void TensorValueToTensor(const ValuePtr &value, std::vector<tensor::TensorPtr> *tensors) {
|
|
|
|
@ -301,7 +290,7 @@ void TensorValueToTensor(const ValuePtr &value, std::vector<tensor::TensorPtr> *
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} else if (value->isa<tensor::Tensor>()) {
|
|
|
|
|
tensor::TensorPtr tensor = value->cast<tensor::TensorPtr>();
|
|
|
|
|
auto tensor = value->cast<tensor::TensorPtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(tensor);
|
|
|
|
|
tensors->push_back(tensor);
|
|
|
|
|
}
|
|
|
|
|