|
|
|
@ -288,6 +288,7 @@ py::object DoParamMixPrecisionCastTuple(bool *is_cast, const py::tuple tuple) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool GetSignatureType(const PrimitivePyPtr &prim, std::vector<SignatureEnumDType> *dtypes) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(dtypes);
|
|
|
|
|
auto signature = prim->signatures();
|
|
|
|
|
bool has_sig_dtype = false;
|
|
|
|
|
(void)std::transform(signature.begin(), signature.end(), std::back_inserter(*dtypes),
|
|
|
|
@ -733,20 +734,29 @@ ValuePtr PynativeExecutor::GetForwardValue(const OpExecInfoPtr &op_exec_info) {
|
|
|
|
|
|
|
|
|
|
AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector<bool> *op_masks,
|
|
|
|
|
abstract::AbstractBasePtrList *args_spec_list) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(op_masks);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(args_spec_list);
|
|
|
|
|
CNodePtr cnode = nullptr;
|
|
|
|
|
std::vector<AnfNodePtr> inputs;
|
|
|
|
|
|
|
|
|
|
auto prim = op_exec_info->py_primitive;
|
|
|
|
|
const auto &signature = prim->signatures();
|
|
|
|
|
|
|
|
|
|
inputs.push_back(NewValueNode(prim));
|
|
|
|
|
|
|
|
|
|
size_t size = op_exec_info->op_inputs.size();
|
|
|
|
|
auto sig_size = signature.size();
|
|
|
|
|
// ignore signature for cast op
|
|
|
|
|
if (sig_size > 0 && sig_size != size) {
|
|
|
|
|
MS_EXCEPTION(ValueError) << op_exec_info->op_name << " inputs size " << size << " does not match the requires "
|
|
|
|
|
<< "inputs size " << sig_size;
|
|
|
|
|
}
|
|
|
|
|
bool is_cast_op = (op_exec_info->op_name == "Cast");
|
|
|
|
|
if (!is_cast_op) {
|
|
|
|
|
const auto &signature = prim->signatures();
|
|
|
|
|
for (size_t i = 0; i < size; i++) {
|
|
|
|
|
auto obj = op_exec_info->op_inputs[i];
|
|
|
|
|
auto sig = SignatureEnumRW::kRWDefault;
|
|
|
|
|
if (signature.size() > 0) {
|
|
|
|
|
if (sig_size > 0) {
|
|
|
|
|
sig = signature[i].rw;
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(DEBUG) << "check mix precision " << op_exec_info->op_name << " input " << i << " "
|
|
|
|
|