|
|
@ -37,7 +37,7 @@ void SparseSoftmaxCrossEntropyWithLogitsCPUKernel::InitKernel(const CNodePtr &ke
|
|
|
|
std::vector<size_t> shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
|
|
|
|
std::vector<size_t> shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
|
|
|
|
std::vector<size_t> label_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
|
|
|
|
std::vector<size_t> label_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
|
|
|
|
if (label_shape.size() > 1) {
|
|
|
|
if (label_shape.size() > 1) {
|
|
|
|
MS_LOG(EXCEPTION) << "label shape should be 1D";
|
|
|
|
MS_LOG(EXCEPTION) << "Labels shape length should be equal to Logits shape length minus 1";
|
|
|
|
}
|
|
|
|
}
|
|
|
|
dnnl::memory::dims mem_dims;
|
|
|
|
dnnl::memory::dims mem_dims;
|
|
|
|
mem_dims.insert(mem_dims.end(), shape.begin(), shape.end());
|
|
|
|
mem_dims.insert(mem_dims.end(), shape.begin(), shape.end());
|
|
|
|