|
|
|
@ -422,11 +422,14 @@ void ConstructInputTensor(const OpExecInfoPtr &op_run_info, std::vector<int64_t>
|
|
|
|
|
MS_LOG(INFO) << "current node is dynamic shape: " << op_run_info->op_name;
|
|
|
|
|
reg_exist = false;
|
|
|
|
|
}
|
|
|
|
|
auto ms_context = MsContext::GetInstance();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(ms_context);
|
|
|
|
|
if (op_run_info->op_name == prim::kPrimEmbeddingLookup->name()) {
|
|
|
|
|
reg_exist = false;
|
|
|
|
|
if (ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kCPUDevice) {
|
|
|
|
|
reg_exist = false;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (op_run_info->op_name == prim::kPrimGatherD->name()) {
|
|
|
|
|
auto ms_context = MsContext::GetInstance();
|
|
|
|
|
// Gather op needs converting const input to attr on GPU device
|
|
|
|
|
if (ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kGPUDevice) {
|
|
|
|
|
reg_exist = false;
|
|
|
|
|