fix embeddinglookup error on cpu pynative mode

pull/10010/head
chujinjin 4 years ago
parent 0c88e3f256
commit 98d6898438

@ -422,11 +422,14 @@ void ConstructInputTensor(const OpExecInfoPtr &op_run_info, std::vector<int64_t>
MS_LOG(INFO) << "current node is dynamic shape: " << op_run_info->op_name;
reg_exist = false;
}
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
if (op_run_info->op_name == prim::kPrimEmbeddingLookup->name()) {
reg_exist = false;
if (ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kCPUDevice) {
reg_exist = false;
}
}
if (op_run_info->op_name == prim::kPrimGatherD->name()) {
auto ms_context = MsContext::GetInstance();
// Gather op needs converting const input to attr on GPU device
if (ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kGPUDevice) {
reg_exist = false;

Loading…
Cancel
Save