diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/transpose_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/transpose_fp16.cc
index 1c35cdb617..03eae3afa1 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp16/transpose_fp16.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp16/transpose_fp16.cc
@@ -153,8 +153,8 @@ int TransposeFp16CPUKernel::Run() {
     fp16_out_data_ = reinterpret_cast<float16_t *>(out_tensor->MutableData());
   }
 
-  in_shape_ = const_cast<int *>(in_tensor->shape().data());
-  out_shape_ = const_cast<int *>(out_tensor->shape().data());
+  memcpy(in_shape_, in_tensor->shape().data(), in_tensor->shape().size() * sizeof(int));
+  memcpy(out_shape_, out_tensor->shape().data(), out_tensor->shape().size() * sizeof(int));
 
   ret = ParallelLaunch(this->context_->thread_pool_, TransposeFp16Run, this, thread_h_num_);
   if (ret != RET_OK) {
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/transpose_fp16.h b/mindspore/lite/src/runtime/kernel/arm/fp16/transpose_fp16.h
index 29c959d4a6..79251b9486 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp16/transpose_fp16.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp16/transpose_fp16.h
@@ -48,8 +48,8 @@ class TransposeFp16CPUKernel : public LiteKernel {
   float *out_data_;
   float16_t *fp16_in_data_ = nullptr;
   float16_t *fp16_out_data_ = nullptr;
-  int *in_shape_;
-  int *out_shape_;
+  int in_shape_[8];
+  int out_shape_[8];
 };
 }  // namespace mindspore::kernel
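[review note] The transpose kernels previously cached the raw pointer returned by shape().data(). If shape() returns its vector by value, that pointer dangles the moment the temporary is destroyed; copying the dims into the kernel-owned in_shape_[8]/out_shape_[8] arrays removes the lifetime coupling. A minimal sketch of the hazard, using a hypothetical Tensor stand-in (not MindSpore code):

```cpp
// Minimal sketch (not MindSpore code) of why holding shape().data() is unsafe
// when shape() returns the vector by value: the buffer dies with the temporary.
#include <cstring>
#include <vector>

struct Tensor {                                    // hypothetical stand-in for lite::Tensor
  std::vector<int> dims;
  std::vector<int> shape() const { return dims; }  // returns by value
};

int main() {
  Tensor t{{1, 2, 3, 4}};

  // Unsafe: points into a temporary vector destroyed at the end of this full
  // expression -- any later read is a use-after-free.
  const int *dangling = t.shape().data();
  (void)dangling;

  // Safe: copy the dims into storage the kernel owns, as the patch does.
  int shape_copy[8] = {0};
  std::vector<int> dims = t.shape();
  std::memcpy(shape_copy, dims.data(), dims.size() * sizeof(int));
  return shape_copy[0];
}
```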
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc
index 61113f56fe..2921df8753 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc
@@ -30,7 +30,10 @@ using mindspore::schema::PrimitiveType_Eltwise;
 
 namespace mindspore::kernel {
 
-ArithmeticCPUKernel::~ArithmeticCPUKernel() {}
+ArithmeticCPUKernel::~ArithmeticCPUKernel() {
+  FreeTmpPtr();
+  return;
+}
 
 int ArithmeticCPUKernel::Init() {
   if (!InferShapeDone()) {
@@ -39,6 +42,59 @@ int ArithmeticCPUKernel::Init() {
   return ReSize();
 }
 
+int ArithmeticCPUKernel::InitBroadCastCase() {
+  /* if a const input needs broadcasting,
+   * and every input that needs broadcasting is const,
+   * do the broadcast once here at resize time */
+
+  if (arithmeticParameter_->broadcasting_ == false) {
+    return RET_OK;
+  }
+
+  int broadcast_size = out_tensors_[0]->Size();
+  if (broadcast_size < 0) {
+    return RET_OK;
+  }
+
+  if (arithmeticParameter_->in_elements_num0_ != arithmeticParameter_->out_elements_num_ &&
+      arithmeticParameter_->in_elements_num1_ != arithmeticParameter_->out_elements_num_) {
+    /* e.g. [1, 1, 2] + [1, 2, 1] -> [1, 2, 2]:
+     * both inputs need broadcasting, keep it at runtime */
+    return RET_OK;
+  }
+
+  FreeTmpPtr();
+
+  CalcMultiplesAndStrides(arithmeticParameter_);
+
+  if (in_tensors_[0]->data_c() != nullptr &&
+      arithmeticParameter_->in_elements_num1_ == arithmeticParameter_->out_elements_num_) {
+    input0_ptr_ = malloc(broadcast_size);
+    if (input0_ptr_ == nullptr) {
+      return RET_ERROR;
+    }
+    TileOneDimension(reinterpret_cast<float *>(in_tensors_[0]->data_c()), reinterpret_cast<float *>(input0_ptr_), 0,
+                     arithmeticParameter_->ndim_, arithmeticParameter_->in_shape0_, arithmeticParameter_->in_strides0_,
+                     arithmeticParameter_->out_strides_, arithmeticParameter_->multiples0_);
+    arithmeticParameter_->broadcasting_ = false;
+    input0_broadcast_ = true;
+  }
+  if (in_tensors_[1]->data_c() != nullptr &&
+      arithmeticParameter_->in_elements_num0_ == arithmeticParameter_->out_elements_num_) {
+    input1_ptr_ = malloc(broadcast_size);
+    if (input1_ptr_ == nullptr) {
+      FreeTmpPtr();
+      return RET_ERROR;
+    }
+    TileOneDimension(reinterpret_cast<float *>(in_tensors_[1]->data_c()), reinterpret_cast<float *>(input1_ptr_), 0,
+                     arithmeticParameter_->ndim_, arithmeticParameter_->in_shape1_, arithmeticParameter_->in_strides1_,
+                     arithmeticParameter_->out_strides_, arithmeticParameter_->multiples1_);
+    arithmeticParameter_->broadcasting_ = false;
+    input1_broadcast_ = true;
+  }
+  return RET_OK;
+}
+
 int ArithmeticCPUKernel::PreProcess() {
   if (!InferShapeDone()) {
     (const_cast<mindspore::lite::PrimitiveC *>(primitive_))->set_infer_flag(true);
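[review note] InitBroadCastCase() moves broadcasting of const inputs out of the hot path: when exactly one side needs expanding and its data is already resident at resize time, it is tiled once into a malloc'ed buffer and broadcasting_ is cleared, so Run() can take the plain element-wise path. A simplified, self-contained sketch of that pre-tiling idea for the degenerate [1, n] -> [m, n] case; TileOneDimension in the patch is the general, stride-driven version of this:

```cpp
// Simplified illustration of the pre-tiling idea behind InitBroadCastCase:
// expand a constant [1, n] input to [m, n] once, so the hot loop can run a
// plain element-wise kernel with no per-element broadcast arithmetic.
#include <cstdio>
#include <vector>

// Repeat one row m times -- the degenerate broadcast [1, n] -> [m, n].
std::vector<float> TileRow(const std::vector<float> &row, int m) {
  std::vector<float> out;
  out.reserve(row.size() * m);
  for (int i = 0; i < m; ++i) {
    out.insert(out.end(), row.begin(), row.end());
  }
  return out;
}

int main() {
  std::vector<float> row = {1.0f, 2.0f};       // const input, shape [1, 2]
  std::vector<float> tiled = TileRow(row, 3);  // tiled once to shape [3, 2]
  for (float v : tiled) printf("%.1f ", v);    // 1.0 2.0 1.0 2.0 1.0 2.0
  printf("\n");
  return 0;
}
```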
@@ -73,25 +129,98 @@ int ArithmeticCPUKernel::PreProcess() {
   return RET_OK;
 }
 
-int ArithmeticCPUKernel::ReSize() {
-  auto arithmetic_lite_primitive = (lite::Arithmetic *)primitive_;
-  arithmeticParameter_->broadcasting_ = arithmetic_lite_primitive->Broadcasting();
-  arithmeticParameter_->ndim_ = arithmetic_lite_primitive->NDims();
-  if (in_tensors_[0]->data_type() == kNumberTypeFloat32 || in_tensors_[0]->data_type() == kNumberTypeFloat16) {
-    data_type_ = kDataTypeFloat;
-  } else {
-    data_type_ = kDataTypeInt;
+void ArithmeticCPUKernel::InitRunFunction() {
+  switch (op_parameter_->type_) {
+    case PrimitiveType_Mul:
+      switch (arithmeticParameter_->activation_type_) {
+        case schema::ActivationType_RELU:
+          arithmetic_run_ = ElementMulRelu;
+          arithmetic_run_int_ = ElementMulReluInt;
+          break;
+        case schema::ActivationType_RELU6:
+          arithmetic_run_ = ElementMulRelu6;
+          arithmetic_run_int_ = ElementMulRelu6Int;
+          break;
+        default:
+          arithmetic_run_ = ElementMul;
+          arithmetic_run_int_ = ElementMulInt;
+          break;
+      }
+      break;
+    case PrimitiveType_Add:
+      switch (arithmeticParameter_->activation_type_) {
+        case schema::ActivationType_RELU:
+          arithmetic_run_ = ElementAddRelu;
+          break;
+        case schema::ActivationType_RELU6:
+          arithmetic_run_ = ElementAddRelu6;
+          break;
+        default:
+          arithmetic_run_ = ElementAdd;
+          arithmetic_run_int_ = ElementAddInt;
+          break;
+      }
+      break;
+    case PrimitiveType_Sub:
+      switch (arithmeticParameter_->activation_type_) {
+        case schema::ActivationType_RELU:
+          arithmetic_run_ = ElementSubRelu;
+          break;
+        case schema::ActivationType_RELU6:
+          arithmetic_run_ = ElementSubRelu6;
+          break;
+        default:
+          arithmetic_run_ = ElementSub;
+          arithmetic_run_int_ = ElementSubInt;
+          break;
+      }
+      break;
+    case PrimitiveType_Div:
+    case PrimitiveType_RealDiv:
+      switch (arithmeticParameter_->activation_type_) {
+        case schema::ActivationType_RELU:
+          arithmetic_run_ = ElementDivRelu;
+          break;
+        case schema::ActivationType_RELU6:
+          arithmetic_run_ = ElementDivRelu6;
+          break;
+        default:
+          arithmetic_run_ = ElementDiv;
+          break;
+      }
+      break;
+    case PrimitiveType_LogicalAnd:
+      arithmetic_run_ = ElementLogicalAnd;
+      break;
+    case PrimitiveType_LogicalOr:
+      arithmetic_run_ = ElementLogicalOr;
+      break;
+    case PrimitiveType_Maximum:
+      arithmetic_run_ = ElementMaximum;
+      break;
+    case PrimitiveType_Minimum:
+      arithmetic_run_ = ElementMinimum;
+      break;
+    case PrimitiveType_FloorDiv:
+      arithmetic_run_ = ElementFloorDiv;
+      arithmetic_run_int_ = ElementFloorDivInt;
+      break;
+    case PrimitiveType_FloorMod:
+      arithmetic_run_ = ElementFloorMod;
+      arithmetic_run_int_ = ElementFloorModInt;
+      break;
+    case PrimitiveType_SquaredDifference:
+      arithmetic_run_ = ElementSquaredDifference;
+      break;
+    default:
+      MS_LOG(ERROR) << "Error Operator type " << op_parameter_->type_;
+      arithmetic_run_ = nullptr;
+      break;
   }
-  arithmeticParameter_->in_elements_num0_ = in_tensors_[0]->ElementsNum();
-  arithmeticParameter_->in_elements_num1_ = in_tensors_[1]->ElementsNum();
-  arithmeticParameter_->out_elements_num_ = out_tensors_[0]->ElementsNum();
-  memcpy(arithmeticParameter_->in_shape0_, reinterpret_cast<const lite::Arithmetic *>(primitive_)->InShape0().data(),
-         reinterpret_cast<const lite::Arithmetic *>(primitive_)->InShape0().size() * sizeof(int));
-  memcpy(arithmeticParameter_->in_shape1_, reinterpret_cast<const lite::Arithmetic *>(primitive_)->InShape1().data(),
-         reinterpret_cast<const lite::Arithmetic *>(primitive_)->InShape1().size() * sizeof(int));
-  memcpy(arithmeticParameter_->out_shape_, reinterpret_cast<const lite::Arithmetic *>(primitive_)->OutputShape().data(),
-         reinterpret_cast<const lite::Arithmetic *>(primitive_)->OutputShape().size() * sizeof(int));
+  return;
+}
 
+void ArithmeticCPUKernel::InitOptRunFunction() {
   if (arithmeticParameter_->in_elements_num0_ == 1 || arithmeticParameter_->in_elements_num1_ == 1) {
     switch (arithmeticParameter_->op_parameter_.type_) {
      case PrimitiveType_Mul:
@@ -163,23 +292,45 @@ int ArithmeticCPUKernel::ReSize() {
           break;
       }
       break;
-      case PrimitiveType_Equal:
-      case PrimitiveType_Less:
-      case PrimitiveType_Greater:
-      case PrimitiveType_NotEqual:
-      case PrimitiveType_LessEqual:
-      case PrimitiveType_GreaterEqual:
+      default:
        arithmetic_opt_run_ = nullptr;
        arithmetic_opt_run_int_ = nullptr;
        break;
-      default:
-        break;
     }
   } else {
     arithmetic_opt_run_ = nullptr;
     arithmetic_opt_run_int_ = nullptr;
   }
-  return RET_OK;
+  return;
+}
+
+void ArithmeticCPUKernel::InitParam() {
+  auto arithmetic_lite_primitive = (lite::Arithmetic *)primitive_;
+  arithmeticParameter_->broadcasting_ = arithmetic_lite_primitive->Broadcasting();
+  arithmeticParameter_->ndim_ = arithmetic_lite_primitive->NDims();
+  if (in_tensors_[0]->data_type() == kNumberTypeFloat32 || in_tensors_[0]->data_type() == kNumberTypeFloat16) {
+    data_type_ = kDataTypeFloat;
+  } else {
+    data_type_ = kDataTypeInt;
+  }
+
+  arithmeticParameter_->in_elements_num0_ = in_tensors_[0]->ElementsNum();
+  arithmeticParameter_->in_elements_num1_ = in_tensors_[1]->ElementsNum();
+  arithmeticParameter_->out_elements_num_ = out_tensors_[0]->ElementsNum();
+  memcpy(arithmeticParameter_->in_shape0_, reinterpret_cast<const lite::Arithmetic *>(primitive_)->InShape0().data(),
+         reinterpret_cast<const lite::Arithmetic *>(primitive_)->InShape0().size() * sizeof(int));
+  memcpy(arithmeticParameter_->in_shape1_, reinterpret_cast<const lite::Arithmetic *>(primitive_)->InShape1().data(),
+         reinterpret_cast<const lite::Arithmetic *>(primitive_)->InShape1().size() * sizeof(int));
+  memcpy(arithmeticParameter_->out_shape_, reinterpret_cast<const lite::Arithmetic *>(primitive_)->OutputShape().data(),
+         reinterpret_cast<const lite::Arithmetic *>(primitive_)->OutputShape().size() * sizeof(int));
+
+  return;
+}
+
+int ArithmeticCPUKernel::ReSize() {
+  InitParam();
+  InitOptRunFunction();
+  return InitBroadCastCase();
 }
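[review note] ReSize() is now a three-step pipeline (InitParam, InitOptRunFunction, InitBroadCastCase), and the operator/activation switch selects element-wise kernels once, up front, through function pointers. A sketch of that dispatch pattern; the signature below, (in0, in1, out, size) -> status, is an assumption inferred from the call sites in this diff, not the exact nnacl prototype:

```cpp
// Sketch of function-pointer dispatch as used by InitRunFunction: pick the
// kernel once per op type, then call through the pointer with no per-element
// branching. Names ending in "Sketch" are illustrative, not MindSpore APIs.
#include <cstdio>

using ArithmeticRun = int (*)(const float *in0, const float *in1, float *out, int size);

int ElementAddSketch(const float *in0, const float *in1, float *out, int size) {
  for (int i = 0; i < size; ++i) out[i] = in0[i] + in1[i];
  return 0;
}

int ElementMulSketch(const float *in0, const float *in1, float *out, int size) {
  for (int i = 0; i < size; ++i) out[i] = in0[i] * in1[i];
  return 0;
}

int main() {
  ArithmeticRun run = ElementMulSketch;  // chosen once, as in InitRunFunction
  float a[] = {1, 2, 3}, b[] = {4, 5, 6}, c[3];
  run(a, b, c, 3);
  printf("%g %g %g\n", c[0], c[1], c[2]);  // 4 10 18
  return 0;
}
```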
@@ -229,7 +380,8 @@ int ArithmeticCPUKernel::DoArithmetic(int task_id) {
   }
 
   int error_code;
-  if (arithmeticParameter_->broadcasting_) {  // need broadcast
+  if (arithmeticParameter_->broadcasting_) {
+    /* need broadcast in runtime */
     stride = UP_DIV(outside_, thread_count_);
     int out_count = MSMIN(stride, outside_ - stride * task_id);
     if (out_count <= 0) {
@@ -237,59 +389,57 @@ int ArithmeticCPUKernel::DoArithmetic(int task_id) {
     }
     int out_thread_stride = stride * task_id;
     if (data_type_ == kDataTypeFloat) {
-      error_code = BroadcastRun(reinterpret_cast<float *>(in_tensors_[0]->data_c()),
-                                reinterpret_cast<float *>(in_tensors_[1]->data_c()),
+      error_code = BroadcastRun(reinterpret_cast<float *>(input0_ptr_), reinterpret_cast<float *>(input1_ptr_),
                                 reinterpret_cast<float *>(out_tensors_[0]->data_c()), 0, out_count, out_thread_stride);
     } else {
-      error_code = BroadcastRun(reinterpret_cast<int *>(in_tensors_[0]->data_c()),
-                                reinterpret_cast<int *>(in_tensors_[1]->data_c()),
+      error_code = BroadcastRun(reinterpret_cast<int *>(input0_ptr_), reinterpret_cast<int *>(input1_ptr_),
                                 reinterpret_cast<int *>(out_tensors_[0]->data_c()), 0, out_count, out_thread_stride);
     }
+    return error_code;
+  }
 
-  } else if (arithmetic_opt_run_ != nullptr) {  // no broadcast, one of input is scalar
+  if (arithmetic_opt_run_ != nullptr) {
+    /* run the opt function:
+     * one of the inputs is a scalar */
     if (arithmeticParameter_->in_elements_num0_ == 1) {
       if (data_type_ == kDataTypeFloat) {
-        error_code = arithmetic_opt_run_(reinterpret_cast<float *>(in_tensors_[0]->data_c()),
-                                         reinterpret_cast<float *>(in_tensors_[1]->data_c()) + stride * task_id,
-                                         reinterpret_cast<float *>(out_tensors_[0]->data_c()) + stride * task_id, count,
-                                         arithmeticParameter_);
+        error_code = arithmetic_opt_run_(
+          reinterpret_cast<float *>(input0_ptr_), reinterpret_cast<float *>(input1_ptr_) + stride * task_id,
+          reinterpret_cast<float *>(out_tensors_[0]->data_c()) + stride * task_id, count, arithmeticParameter_);
       } else {
-        error_code = arithmetic_opt_run_int_(reinterpret_cast<int *>(in_tensors_[0]->data_c()),
-                                             reinterpret_cast<int *>(in_tensors_[1]->data_c()) + stride * task_id,
-                                             reinterpret_cast<int *>(out_tensors_[0]->data_c()) + stride * task_id,
-                                             count, arithmeticParameter_);
+        error_code = arithmetic_opt_run_int_(
+          reinterpret_cast<int *>(input0_ptr_), reinterpret_cast<int *>(input1_ptr_) + stride * task_id,
+          reinterpret_cast<int *>(out_tensors_[0]->data_c()) + stride * task_id, count, arithmeticParameter_);
       }
     } else if (arithmeticParameter_->in_elements_num1_ == 1) {
       if (data_type_ == kDataTypeFloat) {
-        error_code = arithmetic_opt_run_(reinterpret_cast<float *>(in_tensors_[0]->data_c()) + stride * task_id,
-                                         reinterpret_cast<float *>(in_tensors_[1]->data_c()),
-                                         reinterpret_cast<float *>(out_tensors_[0]->data_c()) + stride * task_id, count,
-                                         arithmeticParameter_);
+        error_code = arithmetic_opt_run_(
+          reinterpret_cast<float *>(input0_ptr_) + stride * task_id, reinterpret_cast<float *>(input1_ptr_),
+          reinterpret_cast<float *>(out_tensors_[0]->data_c()) + stride * task_id, count, arithmeticParameter_);
       } else {
-        error_code = arithmetic_opt_run_int_(reinterpret_cast<int *>(in_tensors_[0]->data_c()) + stride * task_id,
-                                             reinterpret_cast<int *>(in_tensors_[1]->data_c()),
-                                             reinterpret_cast<int *>(out_tensors_[0]->data_c()) + stride * task_id,
-                                             count, arithmeticParameter_);
+        error_code = arithmetic_opt_run_int_(
+          reinterpret_cast<int *>(input0_ptr_) + stride * task_id, reinterpret_cast<int *>(input1_ptr_),
+          reinterpret_cast<int *>(out_tensors_[0]->data_c()) + stride * task_id, count, arithmeticParameter_);
       }
     } else {
       MS_LOG(ERROR) << "Arithmetic opt run: at least one of inputs is scalar";
       return RET_ERROR;
     }
-  } else {  // no broadcast, neither is scalar, two same shape
-    if (data_type_ == kDataTypeFloat) {
-      error_code = arithmetic_run_(reinterpret_cast<float *>(in_tensors_[0]->data_c()) + stride * task_id,
-                                   reinterpret_cast<float *>(in_tensors_[1]->data_c()) + stride * task_id,
-                                   reinterpret_cast<float *>(out_tensors_[0]->data_c()) + stride * task_id, count);
-    } else {
-      error_code = arithmetic_run_int_(reinterpret_cast<int *>(in_tensors_[0]->data_c()) + stride * task_id,
-                                       reinterpret_cast<int *>(in_tensors_[1]->data_c()) + stride * task_id,
-                                       reinterpret_cast<int *>(out_tensors_[0]->data_c()) + stride * task_id, count);
-    }
+
+    return error_code;
   }
-  if (error_code != RET_OK) {
-    return RET_ERROR;
+
+  /* no broadcast in runtime */
+  if (data_type_ == kDataTypeFloat) {
+    error_code = arithmetic_run_(reinterpret_cast<float *>(input0_ptr_) + stride * task_id,
+                                 reinterpret_cast<float *>(input1_ptr_) + stride * task_id,
+                                 reinterpret_cast<float *>(out_tensors_[0]->data_c()) + stride * task_id, count);
+  } else {
+    error_code = arithmetic_run_int_(reinterpret_cast<int *>(input0_ptr_) + stride * task_id,
+                                     reinterpret_cast<int *>(input1_ptr_) + stride * task_id,
+                                     reinterpret_cast<int *>(out_tensors_[0]->data_c()) + stride * task_id, count);
   }
-  return RET_OK;
+  return error_code;
 }
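[review note] All three paths in DoArithmetic slice the work the same way: stride = UP_DIV(total, threads), each task takes at most `stride` elements starting at stride * task_id, and a task whose slice is empty returns early. A worked sketch of that partitioning; UP_DIV and MSMIN are redefined locally so the example is self-contained:

```cpp
// Worked sketch of the per-thread partitioning DoArithmetic relies on.
#include <cstdio>

#define UP_DIV(x, y) (((x) + (y) - 1) / (y))
#define MSMIN(a, b) ((a) < (b) ? (a) : (b))

int main() {
  const int total = 5, threads = 4;
  const int stride = UP_DIV(total, threads);  // 2
  for (int task_id = 0; task_id < threads; ++task_id) {
    int count = MSMIN(stride, total - stride * task_id);
    if (count <= 0) {
      printf("task %d: empty slice, early return\n", task_id);  // mirrors `return RET_OK;`
      continue;
    }
    printf("task %d: elements [%d, %d)\n", task_id, stride * task_id, stride * task_id + count);
  }
  return 0;
}
```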
 
 int ArithmeticsRun(void *cdata, int task_id) {
@@ -302,7 +452,22 @@ int ArithmeticsRun(void *cdata, int task_id) {
   return RET_OK;
 }
 
-int ArithmeticCPUKernel::Run() {
+void ArithmeticCPUKernel::FreeTmpPtr() {
+  if (input0_broadcast_ == true && input0_ptr_ != nullptr) {
+    free(input0_ptr_);
+    input0_ptr_ = nullptr;
+    input0_broadcast_ = false;
+  }
+  if (input1_broadcast_ == true && input1_ptr_ != nullptr) {
+    free(input1_ptr_);
+    input1_ptr_ = nullptr;
+    input1_broadcast_ = false;
+  }
+  return;
+}
+
+void ArithmeticCPUKernel::InitParamInRunTime() {
+  /* runs after InferShape, once real shapes are known */
   if (arithmeticParameter_->broadcasting_) {
     outside_ = 1;
     for (auto i = arithmeticParameter_->ndim_ - 1; i >= 0; --i) {
@@ -312,13 +477,24 @@ int ArithmeticCPUKernel::Run() {
       }
       outside_ *= arithmeticParameter_->out_shape_[i];
     }
-    ComputeStrides(arithmeticParameter_->in_shape0_, arithmeticParameter_->in_strides0_, arithmeticParameter_->ndim_);
-    ComputeStrides(arithmeticParameter_->in_shape1_, arithmeticParameter_->in_strides1_, arithmeticParameter_->ndim_);
-    ComputeStrides(arithmeticParameter_->out_shape_, arithmeticParameter_->out_strides_, arithmeticParameter_->ndim_);
   }
+  ComputeStrides(arithmeticParameter_->in_shape0_, arithmeticParameter_->in_strides0_, arithmeticParameter_->ndim_);
+  ComputeStrides(arithmeticParameter_->in_shape1_, arithmeticParameter_->in_strides1_, arithmeticParameter_->ndim_);
+  ComputeStrides(arithmeticParameter_->out_shape_, arithmeticParameter_->out_strides_, arithmeticParameter_->ndim_);
 
-  int error_code = ParallelLaunch(this->context_->thread_pool_, ArithmeticsRun, this, thread_count_);
+  if (!input0_broadcast_) {
+    input0_ptr_ = in_tensors_[0]->data_c();
+  }
+  if (!input1_broadcast_) {
+    input1_ptr_ = in_tensors_[1]->data_c();
+  }
+  return;
+}
+
+int ArithmeticCPUKernel::Run() {
+  InitParamInRunTime();
+
+  int error_code = ParallelLaunch(this->context_->thread_pool_, ArithmeticsRun, this, thread_count_);
   if (error_code != RET_OK) {
     MS_LOG(ERROR) << "Arithmetic function error error_code[" << error_code << "]";
     return RET_ERROR;
@@ -370,5 +546,4 @@ REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_FloorDiv, CpuArithmeticFp32KernelCreator)
 REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_FloorMod, CpuArithmeticFp32KernelCreator)
 REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_SquaredDifference, CpuArithmeticFp32KernelCreator)
 REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Eltwise, CpuArithmeticFp32KernelCreator)
-
 }  // namespace mindspore::kernel
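[review note] Ownership of input0_ptr_/input1_ptr_ is encoded in the inputX_broadcast_ flags: when true, the pointer is a kernel-owned malloc'ed tile that FreeTmpPtr() must free; when false, it merely aliases tensor memory and must be left alone. A sketch of that pattern with hypothetical names (not the kernel's actual types), showing why the flag must be paired with the pointer it guards:

```cpp
// Sketch of the ownership pattern behind input0_ptr_/input0_broadcast_:
// the flag records whether the pointer is a kernel-owned malloc'ed tile
// (free it) or an alias of tensor memory (never free it).
#include <cstdlib>

struct OwnedOrAliased {
  void *ptr = nullptr;
  bool owned = false;  // plays the role of input0_broadcast_

  void AdoptTile(size_t bytes) {  // ReSize: const input pre-broadcast
    Release();
    ptr = malloc(bytes);
    owned = (ptr != nullptr);
  }
  void Alias(void *tensor_data) {  // Run: non-broadcast input
    Release();
    ptr = tensor_data;
    owned = false;
  }
  void Release() {  // FreeTmpPtr / destructor
    if (owned && ptr != nullptr) free(ptr);
    ptr = nullptr;
    owned = false;
  }
  ~OwnedOrAliased() { Release(); }
};

int main() {
  int tensor_buf[4] = {0};
  OwnedOrAliased in0;
  in0.AdoptTile(16 * sizeof(float));  // owned tile
  in0.Alias(tensor_buf);              // frees the tile, then aliases tensor memory
  return 0;
}
```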
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.h b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.h
index 3908058fcf..fee65a2ba2 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.h
@@ -56,111 +56,7 @@ class ArithmeticCPUKernel : public LiteKernel {
                       const mindspore::lite::PrimitiveC *primitive)
       : LiteKernel(parameter, inputs, outputs, ctx, primitive), thread_count_(ctx->thread_num_) {
     arithmeticParameter_ = reinterpret_cast<ArithmeticParameter *>(parameter);
-    switch (parameter->type_) {
-      case PrimitiveType_Mul:
-        switch (arithmeticParameter_->activation_type_) {
-          case schema::ActivationType_RELU:
-            arithmetic_run_ = ElementMulRelu;
-            arithmetic_run_int_ = ElementMulReluInt;
-            break;
-          case schema::ActivationType_RELU6:
-            arithmetic_run_ = ElementMulRelu6;
-            arithmetic_run_int_ = ElementMulRelu6Int;
-            break;
-          default:
-            arithmetic_run_ = ElementMul;
-            arithmetic_run_int_ = ElementMulInt;
-            break;
-        }
-        break;
-      case PrimitiveType_Add:
-        switch (arithmeticParameter_->activation_type_) {
-          case schema::ActivationType_RELU:
-            arithmetic_run_ = ElementAddRelu;
-            break;
-          case schema::ActivationType_RELU6:
-            arithmetic_run_ = ElementAddRelu6;
-            break;
-          default:
-            arithmetic_run_ = ElementAdd;
-            arithmetic_run_int_ = ElementAddInt;
-            break;
-        }
-        break;
-      case PrimitiveType_Sub:
-        switch (arithmeticParameter_->activation_type_) {
-          case schema::ActivationType_RELU:
-            arithmetic_run_ = ElementSubRelu;
-            break;
-          case schema::ActivationType_RELU6:
-            arithmetic_run_ = ElementSubRelu6;
-            break;
-          default:
-            arithmetic_run_ = ElementSub;
-            arithmetic_run_int_ = ElementSubInt;
-            break;
-        }
-        break;
-      case PrimitiveType_Div:
-      case PrimitiveType_RealDiv:
-        switch (arithmeticParameter_->activation_type_) {
-          case schema::ActivationType_RELU:
-            arithmetic_run_ = ElementDivRelu;
-            break;
-          case schema::ActivationType_RELU6:
-            arithmetic_run_ = ElementDivRelu6;
-            break;
-          default:
-            arithmetic_run_ = ElementDiv;
-            break;
-        }
-        break;
-      case PrimitiveType_LogicalAnd:
-        arithmetic_run_ = ElementLogicalAnd;
-        break;
-      case PrimitiveType_LogicalOr:
-        arithmetic_run_ = ElementLogicalOr;
-        break;
-      case PrimitiveType_Maximum:
-        arithmetic_run_ = ElementMaximum;
-        break;
-      case PrimitiveType_Minimum:
-        arithmetic_run_ = ElementMinimum;
-        break;
-      case PrimitiveType_FloorDiv:
-        arithmetic_run_ = ElementFloorDiv;
-        arithmetic_run_int_ = ElementFloorDivInt;
-        break;
-      case PrimitiveType_FloorMod:
-        arithmetic_run_ = ElementFloorMod;
-        arithmetic_run_int_ = ElementFloorModInt;
-        break;
-      case PrimitiveType_Equal:
-        arithmetic_run_ = ElementEqual;
-        break;
-      case PrimitiveType_NotEqual:
-        arithmetic_run_ = ElementNotEqual;
-        break;
-      case PrimitiveType_Less:
-        arithmetic_run_ = ElementLess;
-        break;
-      case PrimitiveType_LessEqual:
-        arithmetic_run_ = ElementLessEqual;
-        break;
-      case PrimitiveType_Greater:
-        arithmetic_run_ = ElementGreater;
-        break;
-      case PrimitiveType_GreaterEqual:
-        arithmetic_run_ = ElementGreaterEqual;
-        break;
-      case PrimitiveType_SquaredDifference:
-        arithmetic_run_ = ElementSquaredDifference;
-        break;
-      default:
-        MS_LOG(ERROR) << "Error Operator type " << parameter->type_;
-        arithmetic_run_ = nullptr;
-        break;
-    }
+    InitRunFunction();
   }
   ~ArithmeticCPUKernel() override;
 
@@ -171,6 +67,20 @@ class ArithmeticCPUKernel : public LiteKernel {
   virtual int DoArithmetic(int task_id);
   virtual int BroadcastRun(void *input0, void *input1, void *output, int dim, int out_count, int out_thread_stride);
 
+ private:
+  void InitRunFunction();
+  void InitOptRunFunction();
+  void InitParam();
+  void FreeTmpPtr();
+  int InitBroadCastCase();
+  void InitParamInRunTime();
+
+ private:
+  bool input0_broadcast_ = false;
+  bool input1_broadcast_ = false;
+  void *input0_ptr_ = nullptr;
+  void *input1_ptr_ = nullptr;
+
  protected:
   int break_pos_ = 0;
   int outside_ = 0;
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/matmul_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/matmul_fp32.cc
index 6796d2133b..8934414480 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/matmul_fp32.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/matmul_fp32.cc
@@ -263,7 +263,6 @@ int MatmulCPUKernel::RunImpl(int task_id) {
   MS_ASSERT(cur_a_ptr_);
   MS_ASSERT(b);
   MS_ASSERT(c);
-  MS_ASSERT(bias);
   if (is_vector_a_) {
     MatVecMul(cur_a_ptr_, b, c, bias, ActType_No, params_->deep_, cur_oc);
   } else {
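[review note] The dropped MS_ASSERT(bias) reflects that a matmul's bias input is optional; the removal only makes sense under the convention, assumed here, that a null bias pointer means "no bias to add". A sketch of a bias-optional matrix-vector multiply (MatVecMulSketch is illustrative, not the nnacl kernel):

```cpp
// Sketch of why asserting on `bias` was wrong: bias is an optional input,
// and treating nullptr as "no bias" lets one kernel serve both cases.
#include <cstdio>

void MatVecMulSketch(const float *a, const float *b, float *c, const float *bias, int deep, int col) {
  for (int j = 0; j < col; ++j) {
    float acc = (bias != nullptr) ? bias[j] : 0.0f;  // bias may be absent
    for (int k = 0; k < deep; ++k) acc += a[k] * b[k * col + j];
    c[j] = acc;
  }
}

int main() {
  const float a[2] = {1, 2};
  const float b[4] = {1, 0, 0, 1};  // 2x2 identity, row-major [deep x col]
  float c[2];
  MatVecMulSketch(a, b, c, /*bias=*/nullptr, 2, 2);  // no bias tensor bound
  printf("%g %g\n", c[0], c[1]);  // 1 2
  return 0;
}
```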
diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/transpose_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/transpose_int8.cc
index 60845e7d89..a09fd5ec83 100644
--- a/mindspore/lite/src/runtime/kernel/arm/int8/transpose_int8.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/int8/transpose_int8.cc
@@ -143,8 +143,8 @@ int TransposeInt8CPUKernel::Run() {
   in_ptr_ = reinterpret_cast<int8_t *>(in_tensor->data_c());
   out_ptr_ = reinterpret_cast<int8_t *>(out_tensor->data_c());
 
-  in_shape_ = in_dims.data();
-  out_shape_ = out_dims.data();
+  memcpy(in_shape_, in_dims.data(), in_dims.size() * sizeof(int));
+  memcpy(out_shape_, out_dims.data(), out_dims.size() * sizeof(int));
 
   int ret = MallocTmpBuf();
   if (ret != RET_OK) {
@@ -157,8 +157,6 @@ int TransposeInt8CPUKernel::Run() {
   }
 
   FreeTmpBuf();
-  in_shape_ = nullptr;
-  out_shape_ = nullptr;
 
   return ret;
 }
diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/transpose_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/transpose_int8.h
index 05d1306765..2f71cec757 100644
--- a/mindspore/lite/src/runtime/kernel/arm/int8/transpose_int8.h
+++ b/mindspore/lite/src/runtime/kernel/arm/int8/transpose_int8.h
@@ -48,8 +48,6 @@ class TransposeInt8CPUKernel : public LiteKernel {
   TransposeParameter *transpose_param_;
   int8_t *in_ptr_ = nullptr;
   int8_t *out_ptr_ = nullptr;
-  int *in_shape_ = nullptr;
-  int *out_shape_ = nullptr;
   int *dim_size_ = nullptr;
   int *position_ = nullptr;
   bool extra_dims_ = false;
@@ -57,6 +55,8 @@ class TransposeInt8CPUKernel : public LiteKernel {
   int thread_h_stride_ = 0;
   int thread_h_num_ = 0;
   int num_unit_ = 0;
+  int in_shape_[8];
+  int out_shape_[8];
 };
 }  // namespace mindspore::kernel
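[review note] All of the new in_shape_[8]/out_shape_[8] members assume the tensor rank never exceeds 8; the memcpy calls copy shape().size() elements unchecked. A defensive sketch of the copy that clamps the rank so a malformed model cannot overflow the array; kMaxShapeSize and CopyShape are hypothetical names mirroring the arrays' assumed cap of 8:

```cpp
// Defensive variant of the shape copy used throughout this patch: clamp the
// rank to the fixed array length instead of trusting the model's dimension count.
#include <algorithm>
#include <cstring>
#include <vector>

constexpr size_t kMaxShapeSize = 8;  // assumed cap, mirrors in_shape_[8]

void CopyShape(int (&dst)[kMaxShapeSize], const std::vector<int> &src) {
  const size_t n = std::min(src.size(), kMaxShapeSize);  // clamp, don't trust rank
  std::memcpy(dst, src.data(), n * sizeof(int));
}

int main() {
  int in_shape[kMaxShapeSize] = {0};
  CopyShape(in_shape, {1, 224, 224, 3});
  return in_shape[3] == 3 ? 0 : 1;
}
```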