diff --git a/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc b/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc index dee678cf8d..bccd2f4af3 100644 --- a/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc +++ b/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc @@ -54,11 +54,8 @@ void CPUKernelRuntime::AssignValueNodeAddress(session::KernelGraph *kernel_graph } auto tensor = node_value->cast(); MS_EXCEPTION_IF_NULL(tensor); - TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(item_node, 0); - if (output_type_id == kTypeUnknown) { - output_type_id = AnfAlgo::GetOutputInferDataType(item_node, 0); - } - size_t type_size = sizeof(TypeIdToType(output_type_id)); + TypeId output_type_id = AnfAlgo::GetOutputInferDataType(item_node, 0); + size_t type_size = GetTypeByte(TypeIdToType(output_type_id)); ShapeVector data_shape = tensor->shape(); size_t tensor_size = std::accumulate(data_shape.begin(), data_shape.end(), type_size, std::multiplies()); DeviceAddressPtr address = nullptr; @@ -245,7 +242,7 @@ void CPUKernelRuntime::BindInputTensorAddressPtr(const session::KernelGraph &ker if (tensor_address != nullptr && tensor_address != address) { tensor->data_sync(false); } - if (tensor->data_type() == address->type_id_) { + if (GetTypeByte(TypeIdToType(tensor->data_type())) == GetTypeByte(TypeIdToType(address->type_id_))) { address->ptr_ = tensor->data_c(); } else { ShapeVector data_shape = tensor->shape(); diff --git a/mindspore/ccsrc/runtime/device/cpu/kernel_select_cpu.cc b/mindspore/ccsrc/runtime/device/cpu/kernel_select_cpu.cc index 029ed2e3f4..bdb2abc19b 100644 --- a/mindspore/ccsrc/runtime/device/cpu/kernel_select_cpu.cc +++ b/mindspore/ccsrc/runtime/device/cpu/kernel_select_cpu.cc @@ -210,7 +210,7 @@ bool SelectKernel(const CNodePtr &kernel_node, KernelAttr *selected_kernel_attr, const std::vector &kernel_attrs, const std::vector &input_formats, const std::vector &input_types, const std::vector &input_not_cnode_indexes, const std::vector &infer_output_formats, const std::vector &infer_output_types, - bool strict) { + std::pair *matched, bool strict) { int max_type_matched_num = -1; int max_format_matched_num = -1; for (auto kernel_attr : kernel_attrs) { @@ -244,10 +244,13 @@ bool SelectKernel(const CNodePtr &kernel_node, KernelAttr *selected_kernel_attr, } // All formats and data types matched if (max_type_matched_num == SizeToInt(input_types.size()) && - max_format_matched_num == SizeToInt(input_types.size()) && - output_type_format_matched_num.first == SizeToInt(infer_output_types.size()) && - output_type_format_matched_num.second == SizeToInt(infer_output_types.size())) { - return true; + max_format_matched_num == SizeToInt(input_types.size())) { + matched->first = true; + if (output_type_format_matched_num.first == SizeToInt(infer_output_types.size()) && + output_type_format_matched_num.second == SizeToInt(infer_output_types.size())) { + matched->second = true; + return true; + } } } return false; @@ -261,22 +264,23 @@ void SetKernelInfo(const CNodePtr &kernel_node) { std::vector infer_output_formats; std::vector infer_output_types; MS_LOG(INFO) << "SetKernelInfo, CNode Name: " << AnfAlgo::GetCNodeName(kernel_node); - GetInputFormatsAndDtypes(kernel_node, &input_formats, &input_types, &input_not_cnode_indexes); - GetOutputInferFormatsAndDtypes(kernel_node, &infer_output_formats, &infer_output_types); auto kernel_attrs = kernel::CPUKernelFactory::GetInstance().GetSupportedKernelAttrList(AnfAlgo::GetCNodeName(kernel_node)); if (kernel_attrs.empty()) { MS_LOG(EXCEPTION) << "Operator[" << AnfAlgo::GetCNodeName(kernel_node) << "] is not support."; } + GetInputFormatsAndDtypes(kernel_node, &input_formats, &input_types, &input_not_cnode_indexes); + GetOutputInferFormatsAndDtypes(kernel_node, &infer_output_formats, &infer_output_types); KernelAttr selected_kernel_attr; - bool matched = true; + std::pair matched = std::make_pair(false, false); if (!SelectKernel(kernel_node, &selected_kernel_attr, kernel_attrs, input_formats, input_types, - input_not_cnode_indexes, infer_output_formats, infer_output_types, true)) { - matched = SelectKernel(kernel_node, &selected_kernel_attr, kernel_attrs, input_formats, input_types, - input_not_cnode_indexes, infer_output_formats, infer_output_types, false); + input_not_cnode_indexes, infer_output_formats, infer_output_types, &matched, true)) { + matched = std::make_pair(false, false); + SelectKernel(kernel_node, &selected_kernel_attr, kernel_attrs, input_formats, input_types, input_not_cnode_indexes, + infer_output_formats, infer_output_types, &matched, false); } - - if (selected_kernel_attr.GetInputSize() > 0 && (matched || input_types.size() == input_not_cnode_indexes.size())) { + if (selected_kernel_attr.GetInputSize() > 0 && + (matched.first || input_types.size() == input_not_cnode_indexes.size())) { MS_LOG(INFO) << "Input format and dtype is matched"; GetOutputFormatsAndDtypes(kernel_node, selected_kernel_attr, &output_formats, &output_types); UpdatePrevNotCNodeFormatDtype(selected_kernel_attr, input_not_cnode_indexes, kernel_node); diff --git a/tests/st/ops/cpu/test_cpu_type.py b/tests/st/ops/cpu/test_cpu_type.py new file mode 100644 index 0000000000..55dfd5564c --- /dev/null +++ b/tests/st/ops/cpu/test_cpu_type.py @@ -0,0 +1,118 @@ +import numpy as np +import pytest +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import operations as P +import mindspore.context as context +from mindspore.nn import Dense +from mindspore.nn import TrainOneStepCell, WithLossCell +from mindspore.nn.optim import Momentum + +context.set_context(mode=context.GRAPH_MODE, device_target='CPU') + + +class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.bias_add = P.BiasAdd() + self.bias_add1 = P.BiasAdd() + + def construct(self, x, b, c): + return self.bias_add1(self.bias_add(x, b), c) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_bias_add1(): + x = np.ones([2, 2]).astype(np.float16) + b = np.array([1, 1]).astype(np.float16) + c = np.array([1, 1]).astype(np.float16) + bias_add = Net() + output = bias_add(Tensor(x), Tensor(b), Tensor(c)) + expect_output = np.ones([2, 2]).astype(np.float16) * 3 + assert np.all(output.asnumpy() == expect_output) + + +class Net1(nn.Cell): + def __init__(self): + super(Net1, self).__init__() + self.bias_add = P.BiasAdd() + self.mul = P.Mul() + + def construct(self, x, a, b): + p1 = self.bias_add(x, b) + p2 = self.bias_add(x, a) + p3 = self.mul(p1, p2) + return p3 + + +class Net2(nn.Cell): + def __init__(self): + super(Net2, self).__init__() + self.bias_add = P.BiasAdd() + self.bias_add1 = P.BiasAdd() + + def construct(self, x, b, c): + return self.bias_add1(self.bias_add(x, b), c) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_bias_add2(): + x = np.ones([2, 2]).astype(np.float32) + a = np.array([1, 1]).astype(np.float32) + b = np.array([1, 1]).astype(np.float32) + c = np.array([1, 1]).astype(np.float32) + bias_add = Net1() + output = bias_add(Tensor(x), Tensor(a), Tensor(b)) + print(output) + + net2 = Net2() + output2 = net2(Tensor(x), Tensor(b), Tensor(c)) + print(output2) + + +context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + + +class MomentumNet(nn.Cell): + def __init__(self): + super(MomentumNet, self).__init__() + self.batch_size = 1 + + self.reshape = P.Reshape() + weight = Tensor(np.ones([10, 16]).astype(np.float32) * 0.01) + self.fc1 = Dense(16, 10, weight_init=weight) + + def construct(self, input_x): + output = self.reshape(input_x, (self.batch_size, -1)) + output = self.fc1(output) + return output + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_momentum(): + epoch = 1 + net = MomentumNet() + learning_rate = (0.1, 0.2) + momentum = 0.9 + + optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), learning_rate, momentum) + criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') + net_with_criterion = WithLossCell(net, criterion) + train_network = TrainOneStepCell(net_with_criterion, optimizer) # optimizer + train_network.set_train() + losses = [] + for _ in range(epoch): + data = Tensor(np.arange(0, 16).reshape(1, 1, 4, 4).astype(np.float32) * 0.01) + label = Tensor(np.array([0]).astype(np.int32)) + loss = train_network(data, label) + losses.append(loss) + print("================================") + print(losses) + + return losses