!9448 fix cpu issues

From: @huaweib
Reviewed-by: @kisnwang,@jjfeing
Signed-off-by: @jjfeing
pull/9448/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 421bf40287

@ -38,16 +38,15 @@ void MKLCPUKernel::GetPadding(const CNodePtr &kernel_node, const std::string &pa
for (size_t i = 0; i < weight_height.size(); ++i) {
auto wh = weight_height[i];
int re = wh % stride;
int pad_along;
if (re == 0) {
re = stride;
}
int pad = kernel_size[i] - re;
padding_l->emplace_back(pad / 2);
if (pad % 2 == 0) {
padding_r->emplace_back(pad / 2);
pad_along = std::max(SizeToInt(kernel_size[i]) - stride, 0);
} else {
padding_r->emplace_back(pad / 2 + 1);
pad_along = std::max(SizeToInt(kernel_size[i]) - re, 0);
}
int pad = pad_along / 2;
padding_l->emplace_back(pad);
padding_r->emplace_back(pad_along - pad);
}
} else if (pad_mode == PAD_MODE_LOWER_VALID || pad_mode == PAD_MODE_UPPER_VALID) {
MS_LOG(INFO) << "pad valid";

@ -257,7 +257,8 @@ void CPUKernelRuntime::BindInputTensorAddressPtr(const session::KernelGraph &ker
MS_EXCEPTION_IF_NULL(address);
MS_EXCEPTION_IF_NULL(tensor);
if (tensor_address != nullptr && tensor_address != address &&
std::dynamic_pointer_cast<device::DeviceAddress>(tensor_address)->DeviceType() != DeviceAddressType::kCPU) {
(std::dynamic_pointer_cast<device::DeviceAddress>(tensor_address)->DeviceType() != DeviceAddressType::kCPU ||
AnfAlgo::IsParameterWeight(item->cast<ParameterPtr>()))) {
tensor->data_sync(false);
}
if (GetTypeByte(TypeIdToType(tensor->data_type())) == GetTypeByte(TypeIdToType(address->type_id_))) {

@ -21,7 +21,7 @@ Examples:
>>> import mindspore.ops as ops
"""
from .primitive import Primitive, PrimitiveWithInfer, prim_attr_register
from .primitive import Primitive, PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register
from .vm_impl_registry import get_vm_impl_fn, vm_impl_registry
from .op_info_register import op_info_register, AkgGpuRegOp, AkgAscendRegOp, AiCPURegOp, TBERegOp, DataType
from .primitive import constexpr
@ -32,7 +32,7 @@ from .operations import *
from .functional import *
__primitive__ = [
"prim_attr_register", "Primitive", "PrimitiveWithInfer", "signature"
"prim_attr_register", "Primitive", "PrimitiveWithInfer", "PrimitiveWithCheck", "signature"
]
__all__ = ["get_vm_impl_fn", "vm_impl_registry",

Loading…
Cancel
Save