|
|
|
@ -76,44 +76,56 @@ bool CompareTensorScalarType(const TypeId &tensor_type, const size_t &t_type_num
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void setMaxType(TypeId *max_type_id, TypeId *max_type, size_t *max_type_number, const TypeId type_id, const TypeId type,
|
|
|
|
|
void SetMaxType(TypeId *max_type_id, TypeId *max_type, size_t *max_type_number, const TypeId type_id, const TypeId type,
|
|
|
|
|
const size_t type_number) {
|
|
|
|
|
*max_type_id = type_id;
|
|
|
|
|
*max_type = type;
|
|
|
|
|
*max_type_number = type_number;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::vector<size_t> indexs,
|
|
|
|
|
const std::set<size_t> &write_indexs) {
|
|
|
|
|
bool GetTensorOrScalarTypeInfo(AbstractBasePtr arg_value, bool is_write, TypeId *arg_type_id,
|
|
|
|
|
TypeId *arg_type = nullptr) {
|
|
|
|
|
if (arg_value->isa<abstract::AbstractRef>()) {
|
|
|
|
|
if (is_write) {
|
|
|
|
|
arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref_origin();
|
|
|
|
|
} else {
|
|
|
|
|
arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (arg_value->isa<abstract::AbstractTensor>()) {
|
|
|
|
|
auto tensor = arg_value->cast<abstract::AbstractTensorPtr>();
|
|
|
|
|
auto tensor_type = tensor->element()->BuildType();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(tensor_type);
|
|
|
|
|
*arg_type_id = tensor_type->type_id();
|
|
|
|
|
if (arg_type != nullptr) {
|
|
|
|
|
*arg_type = kObjectTypeTensorType;
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
if (arg_value->isa<abstract::AbstractScalar>()) {
|
|
|
|
|
auto scalar = arg_value->cast<abstract::AbstractScalarPtr>();
|
|
|
|
|
auto scalar_type = scalar->BuildType();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(scalar_type);
|
|
|
|
|
*arg_type_id = scalar_type->type_id();
|
|
|
|
|
if (arg_type != nullptr) {
|
|
|
|
|
*arg_type = kObjectTypeNumber;
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::vector<size_t> indices,
|
|
|
|
|
const std::set<size_t> &write_indices) {
|
|
|
|
|
TypeId max_type_id = kTypeUnknown;
|
|
|
|
|
TypeId max_type = kTypeUnknown;
|
|
|
|
|
size_t max_type_number = 0;
|
|
|
|
|
bool has_int8 = false;
|
|
|
|
|
for (const auto &index : indexs) {
|
|
|
|
|
for (const auto &index : indices) {
|
|
|
|
|
TypeId arg_type_id = kTypeUnknown;
|
|
|
|
|
TypeId arg_type = kTypeUnknown;
|
|
|
|
|
AbstractBasePtr arg_value = args_spec_list[index];
|
|
|
|
|
if (arg_value->isa<abstract::AbstractRef>()) {
|
|
|
|
|
auto is_write = (write_indexs.find(index) != write_indexs.end());
|
|
|
|
|
if (is_write) {
|
|
|
|
|
arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref_origin();
|
|
|
|
|
} else {
|
|
|
|
|
arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (arg_value->isa<abstract::AbstractTensor>()) {
|
|
|
|
|
auto tensor = arg_value->cast<abstract::AbstractTensorPtr>();
|
|
|
|
|
auto tensor_type = tensor->element()->BuildType();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(tensor_type);
|
|
|
|
|
arg_type_id = tensor_type->type_id();
|
|
|
|
|
arg_type = kObjectTypeTensorType;
|
|
|
|
|
} else if (arg_value->isa<abstract::AbstractScalar>()) {
|
|
|
|
|
auto scalar = arg_value->cast<abstract::AbstractScalarPtr>();
|
|
|
|
|
auto scalar_type = scalar->BuildType();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(scalar_type);
|
|
|
|
|
arg_type_id = scalar_type->type_id();
|
|
|
|
|
arg_type = kObjectTypeNumber;
|
|
|
|
|
} else {
|
|
|
|
|
auto is_write = (write_indices.find(index) != write_indices.end());
|
|
|
|
|
if (!GetTensorOrScalarTypeInfo(args_spec_list[index], is_write, &arg_type_id, &arg_type)) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
auto it = type_map.find(arg_type_id);
|
|
|
|
@ -124,22 +136,22 @@ TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::ve
|
|
|
|
|
has_int8 = true;
|
|
|
|
|
}
|
|
|
|
|
if (max_type_id == kTypeUnknown) {
|
|
|
|
|
setMaxType(&max_type_id, &max_type, &max_type_number, arg_type_id, arg_type, it->second);
|
|
|
|
|
SetMaxType(&max_type_id, &max_type, &max_type_number, arg_type_id, arg_type, it->second);
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (max_type == arg_type) {
|
|
|
|
|
if (it->second > max_type_number) {
|
|
|
|
|
setMaxType(&max_type_id, &max_type, &max_type_number, arg_type_id, arg_type, it->second);
|
|
|
|
|
SetMaxType(&max_type_id, &max_type, &max_type_number, arg_type_id, arg_type, it->second);
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
if (arg_type == kObjectTypeTensorType) {
|
|
|
|
|
if (CompareTensorScalarType(arg_type_id, it->second, max_type_id, max_type_number)) {
|
|
|
|
|
setMaxType(&max_type_id, &max_type, &max_type_number, arg_type_id, arg_type, it->second);
|
|
|
|
|
SetMaxType(&max_type_id, &max_type, &max_type_number, arg_type_id, arg_type, it->second);
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
if (!CompareTensorScalarType(max_type_id, max_type_number, arg_type_id, it->second)) {
|
|
|
|
|
setMaxType(&max_type_id, &max_type, &max_type_number, arg_type_id, arg_type, it->second);
|
|
|
|
|
SetMaxType(&max_type_id, &max_type, &max_type_number, arg_type_id, arg_type, it->second);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -154,28 +166,28 @@ TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::ve
|
|
|
|
|
// Get the largest type of index in the same SignatureEnumDType of arguments.
|
|
|
|
|
std::map<SignatureEnumDType, TypeId> GetMaxDtype(const std::vector<SignatureEnumDType> &dtypes,
|
|
|
|
|
const abstract::AbstractBasePtrList &args_spec_list,
|
|
|
|
|
const std::set<size_t> &write_indexs) {
|
|
|
|
|
const std::set<size_t> &write_indices) {
|
|
|
|
|
// record index for signature.dtypes of the same type
|
|
|
|
|
// eg. [T, T1, T, T2, T, T1, T3] -> {{T:(0,2,4)}, {T1:(1,5)}, {T2:(3)}, {T3:(6)}}
|
|
|
|
|
std::map<SignatureEnumDType, std::vector<size_t>> type_indexs;
|
|
|
|
|
std::map<SignatureEnumDType, std::vector<size_t>> type_indices;
|
|
|
|
|
for (size_t i = 0; i < dtypes.size(); ++i) {
|
|
|
|
|
auto it = type_indexs.find(dtypes[i]);
|
|
|
|
|
if (it == type_indexs.end()) {
|
|
|
|
|
(void)type_indexs.insert(std::make_pair(dtypes[i], std::vector<size_t>{i}));
|
|
|
|
|
auto it = type_indices.find(dtypes[i]);
|
|
|
|
|
if (it == type_indices.end()) {
|
|
|
|
|
(void)type_indices.insert(std::make_pair(dtypes[i], std::vector<size_t>{i}));
|
|
|
|
|
} else {
|
|
|
|
|
it->second.push_back(i);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
std::map<SignatureEnumDType, TypeId> dst_type;
|
|
|
|
|
for (auto it = type_indexs.begin(); it != type_indexs.end(); (void)++it) {
|
|
|
|
|
for (auto it = type_indices.begin(); it != type_indices.end(); (void)++it) {
|
|
|
|
|
auto type = it->first;
|
|
|
|
|
auto indexs = it->second;
|
|
|
|
|
auto indices = it->second;
|
|
|
|
|
// If the number of arguments belonging to the same SignatureEnumDType is less than 2, skip it.
|
|
|
|
|
if (indexs.size() < 2) {
|
|
|
|
|
if (indices.size() < 2) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
bool has_tensor = false;
|
|
|
|
|
for (const auto &index : indexs) {
|
|
|
|
|
for (const auto &index : indices) {
|
|
|
|
|
AbstractBasePtr arg_value = args_spec_list[index];
|
|
|
|
|
if (arg_value->isa<abstract::AbstractRef>()) {
|
|
|
|
|
arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref();
|
|
|
|
@ -189,7 +201,7 @@ std::map<SignatureEnumDType, TypeId> GetMaxDtype(const std::vector<SignatureEnum
|
|
|
|
|
(void)dst_type.insert(std::make_pair(type, kTypeUnknown));
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
(void)dst_type.insert(std::make_pair(type, GetMaxTypeId(args_spec_list, indexs, write_indexs)));
|
|
|
|
|
(void)dst_type.insert(std::make_pair(type, GetMaxTypeId(args_spec_list, indices, write_indices)));
|
|
|
|
|
}
|
|
|
|
|
return dst_type;
|
|
|
|
|
}
|
|
|
|
@ -204,7 +216,7 @@ AnfNodePtr DoCast(const AnfNodePtr ¶m, const TypeId &type_id, const FuncGrap
|
|
|
|
|
|
|
|
|
|
void DoAutoCast(const std::string &func_name, const std::vector<Signature> &signature,
|
|
|
|
|
const abstract::AbstractBasePtrList &args_spec_list, const FuncGraphPtr &graph,
|
|
|
|
|
std::vector<AnfNodePtr> *const op_inputs, const std::set<size_t> &write_indexs) {
|
|
|
|
|
std::vector<AnfNodePtr> *const op_inputs, const std::set<size_t> &write_indices) {
|
|
|
|
|
std::vector<SignatureEnumDType> dtypes;
|
|
|
|
|
(void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes),
|
|
|
|
|
[](const Signature &sig) { return sig.dtype; });
|
|
|
|
@ -213,36 +225,19 @@ void DoAutoCast(const std::string &func_name, const std::vector<Signature> &sign
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
// Stat the index of the arguments with the largest type in the same SignatureEnumDType.
|
|
|
|
|
std::map<SignatureEnumDType, TypeId> dst_type = GetMaxDtype(dtypes, args_spec_list, write_indexs);
|
|
|
|
|
std::map<SignatureEnumDType, TypeId> dst_type = GetMaxDtype(dtypes, args_spec_list, write_indices);
|
|
|
|
|
// Identify which arg requires auto cast
|
|
|
|
|
for (size_t i = 0; i < args_spec_list.size(); ++i) {
|
|
|
|
|
auto it = dst_type.find(dtypes[i]);
|
|
|
|
|
if (it == dst_type.end() || it->second == kTypeUnknown) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
auto rw_it = write_indexs.find(i);
|
|
|
|
|
auto is_write = (rw_it != write_indexs.end());
|
|
|
|
|
auto rw_it = write_indices.find(i);
|
|
|
|
|
auto is_write = (rw_it != write_indices.end());
|
|
|
|
|
|
|
|
|
|
AbstractBasePtr arg_value = args_spec_list[i];
|
|
|
|
|
if (arg_value->isa<abstract::AbstractRef>()) {
|
|
|
|
|
if (is_write) {
|
|
|
|
|
arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref_origin();
|
|
|
|
|
} else {
|
|
|
|
|
arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
TypeId arg_type_id = kTypeUnknown;
|
|
|
|
|
if (arg_value->isa<abstract::AbstractTensor>()) {
|
|
|
|
|
auto tensor = arg_value->cast<abstract::AbstractTensorPtr>();
|
|
|
|
|
auto tensor_type = tensor->element()->BuildType();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(tensor_type);
|
|
|
|
|
arg_type_id = tensor_type->type_id();
|
|
|
|
|
} else if (arg_value->isa<abstract::AbstractScalar>()) {
|
|
|
|
|
auto scalar = arg_value->cast<abstract::AbstractScalarPtr>();
|
|
|
|
|
auto scalar_type = scalar->BuildType();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(scalar_type);
|
|
|
|
|
arg_type_id = scalar_type->type_id();
|
|
|
|
|
}
|
|
|
|
|
AbstractBasePtr arg_value = args_spec_list[i];
|
|
|
|
|
(void)GetTensorOrScalarTypeInfo(arg_value, is_write, &arg_type_id);
|
|
|
|
|
auto it_map = type_map.find(arg_type_id);
|
|
|
|
|
if (it_map == type_map.end()) {
|
|
|
|
|
continue;
|
|
|
|
@ -279,7 +274,7 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
std::vector<AnfNodePtr> op_inputs;
|
|
|
|
|
std::set<size_t> write_indexs;
|
|
|
|
|
std::set<size_t> write_indices;
|
|
|
|
|
op_inputs.push_back(NewValueNode(function));
|
|
|
|
|
// Assume, the write input of op is always the first input. We check if any write op,
|
|
|
|
|
// and add cast op on other inputs to keep the same type with assigned parameter.
|
|
|
|
@ -303,7 +298,7 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func
|
|
|
|
|
param = func_graph->NewCNode({NewValueNode(prim::kPrimGetRefValue), param});
|
|
|
|
|
} else if (sig == SignatureEnumRW::kRWWrite) {
|
|
|
|
|
param = func_graph->NewCNode({NewValueNode(prim::kPrimGetRefOrigin), param});
|
|
|
|
|
write_indexs.insert(i);
|
|
|
|
|
write_indices.insert(i);
|
|
|
|
|
}
|
|
|
|
|
// If sig is SignatureEnumRW::kRWRef, not do anything.
|
|
|
|
|
} else if (sig == SignatureEnumRW::kRWWrite && type->type_id() != kObjectTypeRefKey) {
|
|
|
|
@ -313,7 +308,7 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func
|
|
|
|
|
}
|
|
|
|
|
// process default
|
|
|
|
|
ProcessDefault(func_name, args_spec_list, signature, has_var, &op_inputs);
|
|
|
|
|
DoAutoCast(func_name, signature, args_spec_list, func_graph, &op_inputs, write_indexs);
|
|
|
|
|
DoAutoCast(func_name, signature, args_spec_list, func_graph, &op_inputs, write_indices);
|
|
|
|
|
return func_graph->NewCNode(op_inputs);
|
|
|
|
|
}
|
|
|
|
|
} // namespace
|
|
|
|
|