|
|
|
@ -29,7 +29,9 @@ using mindspore::lite::RET_OK;
|
|
|
|
|
using mindspore::schema::PrimitiveType_FusedBatchNorm;
|
|
|
|
|
|
|
|
|
|
namespace mindspore::kernel {
|
|
|
|
|
FusedBatchnormCPUKernel::~FusedBatchnormCPUKernel() {
|
|
|
|
|
FusedBatchnormCPUKernel::~FusedBatchnormCPUKernel() { FreeTmpBuffer(); }
|
|
|
|
|
|
|
|
|
|
void FusedBatchnormCPUKernel::FreeTmpBuffer() {
|
|
|
|
|
if (scale_addr_ != nullptr) {
|
|
|
|
|
free(scale_addr_);
|
|
|
|
|
scale_addr_ = nullptr;
|
|
|
|
@ -84,10 +86,14 @@ int FusedBatchnormCPUKernel::InitConstTensor() {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int FusedBatchnormCPUKernel::Init() {
|
|
|
|
|
if (context_->infer_shape_interrupt_ && !context_->running_) {
|
|
|
|
|
set_need_reinit();
|
|
|
|
|
if (!InferShapeDone()) {
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
return ReSize();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int FusedBatchnormCPUKernel::ReSize() {
|
|
|
|
|
FreeTmpBuffer();
|
|
|
|
|
auto input_shapes = in_tensors_[0]->shape();
|
|
|
|
|
auto n_dim = input_shapes.size();
|
|
|
|
|
batchnorm_param_->channel_ = input_shapes[n_dim - 1];
|
|
|
|
@ -106,15 +112,6 @@ int FusedBatchnormCPUKernel::Init() {
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int FusedBatchnormCPUKernel::ReSize() {
|
|
|
|
|
auto input_shapes = in_tensors_[0]->shape();
|
|
|
|
|
batchnorm_param_->unit_ = 1;
|
|
|
|
|
for (int i = 0; i < input_shapes.size() - 1; i++) {
|
|
|
|
|
batchnorm_param_->unit_ *= input_shapes[i];
|
|
|
|
|
}
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int FusedBatchnormCPUKernel::Execute(int task_id) {
|
|
|
|
|
FusedBatchNorm(out_addr_, in_addr_, scale_addr_, offset_addr_, mean_addr_, var_addr_, task_id, batchnorm_param_);
|
|
|
|
|
return RET_OK;
|
|
|
|
@ -149,13 +146,16 @@ int FusedBatchnormCPUKernel::Run() {
|
|
|
|
|
|
|
|
|
|
kernel::LiteKernel *CpuFusedBatchnormKernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
|
|
|
|
|
const std::vector<lite::tensor::Tensor *> &outputs,
|
|
|
|
|
OpParameter *opParameter, const lite::Context *ctx,
|
|
|
|
|
OpParameter *op_parameter, const lite::Context *ctx,
|
|
|
|
|
const kernel::KernelKey &desc,
|
|
|
|
|
const mindspore::lite::PrimitiveC *primitive) {
|
|
|
|
|
MS_ASSERT(opParameter != nullptr);
|
|
|
|
|
if (op_parameter == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "Input parameter is nullptr!";
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
MS_ASSERT(desc.type == schema::PrimitiveType_FusedBatchNorm);
|
|
|
|
|
FusedBatchnormCPUKernel *kernel =
|
|
|
|
|
new (std::nothrow) FusedBatchnormCPUKernel(opParameter, inputs, outputs, ctx, primitive);
|
|
|
|
|
new (std::nothrow) FusedBatchnormCPUKernel(op_parameter, inputs, outputs, ctx, primitive);
|
|
|
|
|
if (kernel == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "new FusedBatchnormCPUKernel fail!";
|
|
|
|
|
return nullptr;
|
|
|
|
@ -163,8 +163,8 @@ kernel::LiteKernel *CpuFusedBatchnormKernelCreator(const std::vector<lite::tenso
|
|
|
|
|
auto ret = kernel->Init();
|
|
|
|
|
if (ret != RET_OK) {
|
|
|
|
|
delete kernel;
|
|
|
|
|
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
|
|
|
|
|
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
|
|
|
|
|
MS_LOG(ERROR) << "Init kernel failed, name: " << op_parameter->name_ << ", type: "
|
|
|
|
|
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(op_parameter->type_));
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
return kernel;
|
|
|
|
|