fix bug of arithmetic compare op

pull/13366/head
fuzhiye 4 years ago
parent 1521430658
commit 5c42b5b62c

@ -31,7 +31,7 @@ namespace mindspore::kernel {
int ArithmeticCompareCPUKernel::BroadcastRun(void *input0, void *input1, void *output, int dim, int out_count,
int out_thread_stride) {
if (dim > break_pos_) {
if (in_tensors_[0]->data_type() == kNumberTypeInt) {
if (in_tensors_[0]->data_type() == kNumberTypeInt || in_tensors_[0]->data_type() == kNumberTypeInt32) {
return func_int32_(reinterpret_cast<int *>(input0) + out_thread_stride,
reinterpret_cast<int *>(input1) + out_thread_stride,
reinterpret_cast<uint8_t *>(output) + out_thread_stride, out_count);
@ -44,7 +44,7 @@ int ArithmeticCompareCPUKernel::BroadcastRun(void *input0, void *input1, void *o
int pos0_ = param_->in_shape0_[dim] == 1 ? 0 : i;
int pos1_ = param_->in_shape1_[dim] == 1 ? 0 : i;
int error_code;
if (in_tensors_[0]->data_type() == kNumberTypeInt) {
if (in_tensors_[0]->data_type() == kNumberTypeInt || in_tensors_[0]->data_type() == kNumberTypeInt32) {
error_code = BroadcastRun(reinterpret_cast<int *>(input0) + pos0_ * param_->in_strides0_[dim],
reinterpret_cast<int *>(input1) + pos1_ * param_->in_strides1_[dim],
reinterpret_cast<uint8_t *>(output) + i * param_->out_strides_[dim], dim + 1, out_count,

Loading…
Cancel
Save