|
|
|
@ -31,7 +31,7 @@ namespace mindspore::kernel {
|
|
|
|
|
int ArithmeticCompareCPUKernel::BroadcastRun(void *input0, void *input1, void *output, int dim, int out_count,
|
|
|
|
|
int out_thread_stride) {
|
|
|
|
|
if (dim > break_pos_) {
|
|
|
|
|
if (in_tensors_[0]->data_type() == kNumberTypeInt) {
|
|
|
|
|
if (in_tensors_[0]->data_type() == kNumberTypeInt || in_tensors_[0]->data_type() == kNumberTypeInt32) {
|
|
|
|
|
return func_int32_(reinterpret_cast<int *>(input0) + out_thread_stride,
|
|
|
|
|
reinterpret_cast<int *>(input1) + out_thread_stride,
|
|
|
|
|
reinterpret_cast<uint8_t *>(output) + out_thread_stride, out_count);
|
|
|
|
@ -44,7 +44,7 @@ int ArithmeticCompareCPUKernel::BroadcastRun(void *input0, void *input1, void *o
|
|
|
|
|
int pos0_ = param_->in_shape0_[dim] == 1 ? 0 : i;
|
|
|
|
|
int pos1_ = param_->in_shape1_[dim] == 1 ? 0 : i;
|
|
|
|
|
int error_code;
|
|
|
|
|
if (in_tensors_[0]->data_type() == kNumberTypeInt) {
|
|
|
|
|
if (in_tensors_[0]->data_type() == kNumberTypeInt || in_tensors_[0]->data_type() == kNumberTypeInt32) {
|
|
|
|
|
error_code = BroadcastRun(reinterpret_cast<int *>(input0) + pos0_ * param_->in_strides0_[dim],
|
|
|
|
|
reinterpret_cast<int *>(input1) + pos1_ * param_->in_strides1_[dim],
|
|
|
|
|
reinterpret_cast<uint8_t *>(output) + i * param_->out_strides_[dim], dim + 1, out_count,
|
|
|
|
|