|
|
|
@ -20,6 +20,7 @@
|
|
|
|
|
|
|
|
|
|
#include "abstract/abstract_value.h"
|
|
|
|
|
#include "ir/anf.h"
|
|
|
|
|
#include "ir/dtype.h"
|
|
|
|
|
#include "abstract/dshape.h"
|
|
|
|
|
#include "abstract/param_validator.h"
|
|
|
|
|
#include "frontend/operator/cc_implementations.h"
|
|
|
|
@ -43,15 +44,15 @@ const std::vector<Signature> &GetSignature(const ValuePtr &function) {
|
|
|
|
|
return empty;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ProcessDefault(const std::string &func_name, const AbstractBasePtrList &args_spec_list,
|
|
|
|
|
const std::vector<Signature> &signature, bool has_var, std::vector<AnfNodePtr> *const op_inputs) {
|
|
|
|
|
void ProcessDefault(const std::string &func_name, size_t actual_param_number, const std::vector<Signature> &signature,
|
|
|
|
|
bool has_var, std::vector<AnfNodePtr> *const op_inputs) {
|
|
|
|
|
std::size_t sig_size = signature.size();
|
|
|
|
|
auto positional_size = sig_size;
|
|
|
|
|
if (has_var) {
|
|
|
|
|
positional_size = sig_size - 1;
|
|
|
|
|
}
|
|
|
|
|
if (args_spec_list.size() < positional_size) {
|
|
|
|
|
for (size_t i = args_spec_list.size(); i < sig_size; ++i) {
|
|
|
|
|
if (actual_param_number < positional_size) {
|
|
|
|
|
for (size_t i = actual_param_number; i < sig_size; ++i) {
|
|
|
|
|
auto default_value = signature[i].default_value;
|
|
|
|
|
if (default_value == nullptr) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Function " << func_name << "'s input length is not equal to Signature length.";
|
|
|
|
@ -67,23 +68,11 @@ void SetMaxType(TypeId *max_type_id, size_t *max_type_number, const TypeId type_
|
|
|
|
|
*max_type_number = type_number;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool GetTensorOrScalarTypeInfo(AbstractBasePtr arg_value, bool is_write, TypeId *arg_type_id,
|
|
|
|
|
bool GetTensorOrScalarTypeInfo(TypePtr arg_type_origin, bool is_write, TypeId *arg_type_id,
|
|
|
|
|
TypeId *arg_type = nullptr) {
|
|
|
|
|
if (arg_value->isa<abstract::AbstractRef>()) {
|
|
|
|
|
auto ref = arg_value->cast<abstract::AbstractRefPtr>();
|
|
|
|
|
arg_value = ref->ref();
|
|
|
|
|
if (!is_write && ref->need_cast()) {
|
|
|
|
|
auto tensor_type = ref->target_type();
|
|
|
|
|
*arg_type_id = tensor_type->type_id();
|
|
|
|
|
if (arg_type != nullptr) {
|
|
|
|
|
*arg_type = kObjectTypeTensorType;
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (arg_value->isa<abstract::AbstractTensor>()) {
|
|
|
|
|
auto tensor = arg_value->cast<abstract::AbstractTensorPtr>();
|
|
|
|
|
auto tensor_type = tensor->element()->BuildType();
|
|
|
|
|
if (arg_type_origin->isa<TensorType>()) {
|
|
|
|
|
auto tensor = arg_type_origin->cast<TensorTypePtr>();
|
|
|
|
|
auto tensor_type = tensor->element();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(tensor_type);
|
|
|
|
|
*arg_type_id = tensor_type->type_id();
|
|
|
|
|
if (arg_type != nullptr) {
|
|
|
|
@ -91,9 +80,8 @@ bool GetTensorOrScalarTypeInfo(AbstractBasePtr arg_value, bool is_write, TypeId
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
if (arg_value->isa<abstract::AbstractScalar>()) {
|
|
|
|
|
auto scalar = arg_value->cast<abstract::AbstractScalarPtr>();
|
|
|
|
|
auto scalar_type = scalar->BuildType();
|
|
|
|
|
if (arg_type_origin->isa<Number>()) {
|
|
|
|
|
auto scalar_type = arg_type_origin->cast<NumberPtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(scalar_type);
|
|
|
|
|
*arg_type_id = scalar_type->type_id();
|
|
|
|
|
if (arg_type != nullptr) {
|
|
|
|
@ -104,7 +92,7 @@ bool GetTensorOrScalarTypeInfo(AbstractBasePtr arg_value, bool is_write, TypeId
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::vector<size_t> indices,
|
|
|
|
|
TypeId GetMaxTypeId(const std::vector<TypePtr> &input_types, std::vector<size_t> indices,
|
|
|
|
|
const std::set<size_t> &write_indices) {
|
|
|
|
|
TypeId max_type_id = kTypeUnknown;
|
|
|
|
|
size_t max_type_number = 0;
|
|
|
|
@ -115,7 +103,7 @@ TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::ve
|
|
|
|
|
TypeId arg_type_id = kTypeUnknown;
|
|
|
|
|
TypeId arg_type = kTypeUnknown;
|
|
|
|
|
auto is_write = (write_indices.find(index) != write_indices.end());
|
|
|
|
|
if (!GetTensorOrScalarTypeInfo(args_spec_list[index], is_write, &arg_type_id, &arg_type)) {
|
|
|
|
|
if (!GetTensorOrScalarTypeInfo(input_types[index], is_write, &arg_type_id, &arg_type)) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
if (arg_type != kObjectTypeTensorType) {
|
|
|
|
@ -161,8 +149,8 @@ TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::ve
|
|
|
|
|
|
|
|
|
|
// Get the largest type of index in the same SignatureEnumDType of arguments.
|
|
|
|
|
using MaxTypeMap = std::map<SignatureEnumDType, TypeId>;
|
|
|
|
|
MaxTypeMap GetMaxDtype(const std::vector<SignatureEnumDType> &dtypes,
|
|
|
|
|
const abstract::AbstractBasePtrList &args_spec_list, const std::set<size_t> &write_indices) {
|
|
|
|
|
MaxTypeMap GetMaxDtype(const std::vector<SignatureEnumDType> &dtypes, const std::vector<TypePtr> &input_types,
|
|
|
|
|
const std::set<size_t> &write_indices) {
|
|
|
|
|
// record index for signature.dtypes of the same type
|
|
|
|
|
// eg. [T, T1, T, T2, T, T1, T3] -> {{T:(0,2,4)}, {T1:(1,5)}, {T2:(3)}, {T3:(6)}}
|
|
|
|
|
std::map<SignatureEnumDType, std::vector<size_t>> type_indices;
|
|
|
|
@ -184,11 +172,8 @@ MaxTypeMap GetMaxDtype(const std::vector<SignatureEnumDType> &dtypes,
|
|
|
|
|
}
|
|
|
|
|
bool has_tensor = false;
|
|
|
|
|
for (const auto &index : indices) {
|
|
|
|
|
AbstractBasePtr arg_value = args_spec_list[index];
|
|
|
|
|
if (arg_value->isa<abstract::AbstractRef>()) {
|
|
|
|
|
arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref();
|
|
|
|
|
}
|
|
|
|
|
if (arg_value->isa<abstract::AbstractTensor>()) {
|
|
|
|
|
auto arg_value = input_types[index];
|
|
|
|
|
if (arg_value->isa<TensorType>()) {
|
|
|
|
|
has_tensor = true;
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
@ -197,7 +182,7 @@ MaxTypeMap GetMaxDtype(const std::vector<SignatureEnumDType> &dtypes,
|
|
|
|
|
(void)dst_type.insert(std::make_pair(type, kTypeUnknown));
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
(void)dst_type.insert(std::make_pair(type, GetMaxTypeId(args_spec_list, indices, write_indices)));
|
|
|
|
|
(void)dst_type.insert(std::make_pair(type, GetMaxTypeId(input_types, indices, write_indices)));
|
|
|
|
|
}
|
|
|
|
|
return dst_type;
|
|
|
|
|
}
|
|
|
|
@ -211,7 +196,7 @@ AnfNodePtr DoCast(const AnfNodePtr ¶m, const TypeId &type_id, const FuncGrap
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void DoAutoCast(const std::string &func_name, const std::vector<Signature> &signature,
|
|
|
|
|
const abstract::AbstractBasePtrList &args_spec_list, const FuncGraphPtr &graph,
|
|
|
|
|
const std::vector<TypePtr> &input_types, const FuncGraphPtr &graph,
|
|
|
|
|
std::vector<AnfNodePtr> *const op_inputs, const std::set<size_t> &write_indices) {
|
|
|
|
|
std::vector<SignatureEnumDType> dtypes;
|
|
|
|
|
(void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes),
|
|
|
|
@ -221,9 +206,9 @@ void DoAutoCast(const std::string &func_name, const std::vector<Signature> &sign
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
// Stat the index of the arguments with the largest type in the same SignatureEnumDType.
|
|
|
|
|
std::map<SignatureEnumDType, TypeId> dst_type = GetMaxDtype(dtypes, args_spec_list, write_indices);
|
|
|
|
|
std::map<SignatureEnumDType, TypeId> dst_type = GetMaxDtype(dtypes, input_types, write_indices);
|
|
|
|
|
// Identify which arg requires auto cast
|
|
|
|
|
for (size_t i = 0; i < args_spec_list.size(); ++i) {
|
|
|
|
|
for (size_t i = 0; i < input_types.size(); ++i) {
|
|
|
|
|
auto it = dst_type.find(dtypes[i]);
|
|
|
|
|
if (it == dst_type.end() || it->second == kTypeUnknown) {
|
|
|
|
|
continue;
|
|
|
|
@ -232,7 +217,7 @@ void DoAutoCast(const std::string &func_name, const std::vector<Signature> &sign
|
|
|
|
|
auto is_write = (rw_it != write_indices.end());
|
|
|
|
|
|
|
|
|
|
TypeId arg_type_id = kTypeUnknown;
|
|
|
|
|
AbstractBasePtr arg_value = args_spec_list[i];
|
|
|
|
|
auto arg_value = input_types[i];
|
|
|
|
|
(void)GetTensorOrScalarTypeInfo(arg_value, is_write, &arg_type_id);
|
|
|
|
|
auto it_map = type_name_map.find(arg_type_id);
|
|
|
|
|
if (it_map == type_name_map.end()) {
|
|
|
|
@ -248,7 +233,7 @@ void DoAutoCast(const std::string &func_name, const std::vector<Signature> &sign
|
|
|
|
|
}
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
if (arg_value->isa<abstract::AbstractTensor>() && arg_type_id == it->second) {
|
|
|
|
|
if ((arg_value->isa<TensorType>()) && arg_type_id == it->second) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(DEBUG) << "do cast for inputs " << i << " " << (*op_inputs)[i + 1]->ToString() << " " << arg_type_id
|
|
|
|
@ -275,6 +260,7 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func
|
|
|
|
|
}
|
|
|
|
|
std::vector<AnfNodePtr> op_inputs;
|
|
|
|
|
std::set<size_t> write_indices;
|
|
|
|
|
std::vector<TypePtr> input_types;
|
|
|
|
|
op_inputs.push_back(NewValueNode(function));
|
|
|
|
|
// Assume, the write input of op is always the first input. We check if any write op,
|
|
|
|
|
// and add cast op on other inputs to keep the same type with assigned parameter.
|
|
|
|
@ -292,30 +278,36 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func
|
|
|
|
|
sig = signature[sig_size - 1].rw;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TypePtr type = args_spec_list[i]->GetTypeTrack();
|
|
|
|
|
if (type && type->type_id() == kObjectTypeRef) {
|
|
|
|
|
auto ref_abs = args_spec_list[i]->cast<abstract::AbstractRefPtr>();
|
|
|
|
|
TypePtr type = args_spec_list[i]->BuildType();
|
|
|
|
|
if (type && type->isa<RefType>()) {
|
|
|
|
|
auto cast_type = parse::GetMixedPrecisionTargetType(func_graph);
|
|
|
|
|
if (sig == SignatureEnumRW::kRWRead) {
|
|
|
|
|
param = NewCNode({NewValueNode(prim::kPrimGetRefValue), param}, func_graph);
|
|
|
|
|
if (ref_abs && ref_abs->need_cast()) {
|
|
|
|
|
auto cast = prim::GetPythonOps("cast", "mindspore.ops.functional");
|
|
|
|
|
param = NewCNode({NewValueNode(cast), param, NewValueNode(ref_abs->target_type())}, func_graph);
|
|
|
|
|
auto source_tensor_type = type->cast<TensorTypePtr>();
|
|
|
|
|
if (source_tensor_type != nullptr) {
|
|
|
|
|
auto source_element = source_tensor_type->element();
|
|
|
|
|
if (cast_type != nullptr && IsSubType(source_element, kFloat) && *source_element != *cast_type) {
|
|
|
|
|
auto cast = prim::GetPythonOps("cast", "mindspore.ops.functional");
|
|
|
|
|
param = NewCNode({NewValueNode(cast), param, NewValueNode(cast_type)}, func_graph);
|
|
|
|
|
type = cast_type->type_id() == kNumberTypeFloat16 ? kTensorTypeFP16 : kTensorTypeFP32;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} else if (sig == SignatureEnumRW::kRWWrite) {
|
|
|
|
|
param = NewCNode({NewValueNode(prim::kPrimGetRefValue), param}, func_graph);
|
|
|
|
|
write_indices.insert(i);
|
|
|
|
|
}
|
|
|
|
|
// If sig is SignatureEnumRW::kRWRef, not do anything.
|
|
|
|
|
} else if (sig == SignatureEnumRW::kRWWrite && type->type_id() != kObjectTypeRefKey) {
|
|
|
|
|
MS_EXCEPTION(TypeError) << "Function " << func_name << "'s input " << i << " should be a Parameter.";
|
|
|
|
|
} else if (sig == SignatureEnumRW::kRWWrite &&
|
|
|
|
|
!((type->type_id() == kObjectTypeRef) || (type->type_id() == kObjectTypeRefKey))) {
|
|
|
|
|
MS_EXCEPTION(TypeError) << "Function " << func_name << "'s input " << i << " should be a Parameter, but "
|
|
|
|
|
<< type->ToString();
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(DEBUG) << "Function " << func_name << "'s input " << i << " " << param->DebugString(2) << " type "
|
|
|
|
|
<< args_spec_list[i]->ToString();
|
|
|
|
|
input_types.push_back(type);
|
|
|
|
|
op_inputs.push_back(param);
|
|
|
|
|
}
|
|
|
|
|
// process default
|
|
|
|
|
ProcessDefault(func_name, args_spec_list, signature, has_var, &op_inputs);
|
|
|
|
|
DoAutoCast(func_name, signature, args_spec_list, func_graph, &op_inputs, write_indices);
|
|
|
|
|
ProcessDefault(func_name, args_spec_list.size(), signature, has_var, &op_inputs);
|
|
|
|
|
DoAutoCast(func_name, signature, input_types, func_graph, &op_inputs, write_indices);
|
|
|
|
|
return func_graph->NewCNode(op_inputs);
|
|
|
|
|
}
|
|
|
|
|
} // namespace
|
|
|
|
|