|
|
|
@ -28,17 +28,23 @@ using mindspore::lite::RET_OK;
|
|
|
|
|
using mindspore::schema::PrimitiveType_BatchNorm;
|
|
|
|
|
|
|
|
|
|
namespace mindspore::kernel {
|
|
|
|
|
int BatchnormCPUKernel::Init() { return RET_OK; }
|
|
|
|
|
int BatchnormCPUKernel::Init() {
|
|
|
|
|
auto input_shapes = inputs_[0]->shape();
|
|
|
|
|
auto n_dim = input_shapes.size();
|
|
|
|
|
batchnorm_param_->channel_ = input_shapes[n_dim - 1];
|
|
|
|
|
batchnorm_param_->unit_ = 1;
|
|
|
|
|
for (int i = 0; i < n_dim - 1; i++) {
|
|
|
|
|
batchnorm_param_->unit_ *= input_shapes[i];
|
|
|
|
|
}
|
|
|
|
|
batchnorm_param_->op_parameter_.thread_num_ =
|
|
|
|
|
MSMIN(batchnorm_param_->op_parameter_.thread_num_, batchnorm_param_->unit_);
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int BatchnormCPUKernel::ReSize() { return RET_OK; }
|
|
|
|
|
|
|
|
|
|
int BatchnormCPUKernel::DoExecute(int tid) {
|
|
|
|
|
int count = MSMIN(thread_unit_, units_ - tid * thread_unit_);
|
|
|
|
|
if (count <= 0) {
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
int offset = tid * thread_unit_ * channel_;
|
|
|
|
|
BatchNorm(in_addr_ + offset, mean_addr_, var_addr_, count, channel_, batchnorm_param_->epsilon_, out_addr_ + offset);
|
|
|
|
|
int BatchnormCPUKernel::DoExecute(int task_id) {
|
|
|
|
|
BatchNorm(out_addr_, in_addr_, mean_addr_, var_addr_, task_id, batchnorm_param_);
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -62,15 +68,8 @@ int BatchnormCPUKernel::Run() {
|
|
|
|
|
mean_addr_ = reinterpret_cast<float *>(inputs_.at(1)->Data());
|
|
|
|
|
var_addr_ = reinterpret_cast<float *>(inputs_.at(2)->Data());
|
|
|
|
|
out_addr_ = reinterpret_cast<float *>(outputs_.at(0)->Data());
|
|
|
|
|
auto input_shapes = inputs_[0]->shape();
|
|
|
|
|
channel_ = input_shapes[3];
|
|
|
|
|
units_ = 1;
|
|
|
|
|
for (int i = 0; i < 3; i++) {
|
|
|
|
|
units_ *= input_shapes[i];
|
|
|
|
|
}
|
|
|
|
|
thread_count_ = MSMIN(thread_count_, units_);
|
|
|
|
|
thread_unit_ = UP_DIV(units_, thread_count_);
|
|
|
|
|
int ret = LiteBackendParallelLaunch(BatchNormRun, this, thread_count_);
|
|
|
|
|
|
|
|
|
|
int ret = LiteBackendParallelLaunch(BatchNormRun, this, batchnorm_param_->op_parameter_.thread_num_);
|
|
|
|
|
if (ret != RET_OK) {
|
|
|
|
|
MS_LOG(ERROR) << "BatchnormRun error error_code[" << ret << "]";
|
|
|
|
|
return ret;
|
|
|
|
|