!7621 add lite ops of instanceNorm

Merge pull request !7621 from liuwenhao/master
pull/7621/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 33d2cae607

@ -0,0 +1,42 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "nnacl/fp32/instance_norm.h"
#include <math.h>
#include "nnacl/instance_norm_parameter.h"
#include "nnacl/op_base.h"
void InstanceNormFp32(const void *input, const void *mean, const void *variance, InstanceNormParameter *param,
int task_id, void *output) {
int units_per_thread = UP_DIV(param->unit_, param->op_parameter_.thread_num_);
int completed_units = task_id * units_per_thread;
if (completed_units >= param->unit_) {
return;
}
int cur_unit = MSMIN(units_per_thread, param->unit_ - completed_units);
int cur_offset = completed_units * param->channel_;
for (int n = 0; n < param->batch_; n++) {
for (int hw = 0; hw < cur_unit; hw++) {
for (int c = 0; c < param->channel_; c++) {
float variance_sqrt = sqrt(((const float *)variance)[n * param->channel_ + c] + param->epsilon_);
((float *)output)[cur_offset + c] =
(((const float *)input)[cur_offset + c] - ((const float *)mean)[n * param->channel_ + c]) / variance_sqrt;
}
cur_offset += param->channel_;
}
cur_offset += (param->unit_ - cur_unit) * param->channel_;
}
}

@ -0,0 +1,34 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_NNACL_FP32_INSTANCE_NORM_H_
#define MINDSPORE_LITE_NNACL_FP32_INSTANCE_NORM_H_
#include "nnacl/instance_norm_parameter.h"
#ifdef __cplusplus
extern "C" {
#endif
void InstanceNormFp32(const void *input, const void *mean, const void *variance, InstanceNormParameter *param,
int task_id, void *output);
void FusedInstanceNormFp32(const void *input, const void *scale, const void *offset, const void *mean,
const void *variance, InstanceNormParameter *param, int task_id, void *output);
#ifdef __cplusplus
}
#endif
#endif // MINDSPORE_LITE_NNACL_FP32_INSTANCE_NORM_H_

@ -0,0 +1,32 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_NNACL_INSTANCE_NORM_PARAMETER_H_
#define MINDSPORE_LITE_NNACL_INSTANCE_NORM_PARAMETER_H_
#include "nnacl/op_base.h"
typedef struct InstanceNormParameter {
OpParameter op_parameter_;
float epsilon_;
float momentum_;
int unit_;
int batch_;
int channel_;
bool fused_;
} InstanceNormParameter;
#endif // MINDSPORE_LITE_NNACL_INSTANCE_NORM_PARAMETER_H_

@ -0,0 +1,65 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/instance_norm.h"
#include <memory>
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
float InstanceNorm::GetEpsilon() const { return this->primitive_->value.AsInstanceNorm()->epsilon; }
void InstanceNorm::SetEpsilon(float epsilon) { this->primitive_->value.AsInstanceNorm()->epsilon = epsilon; }
int InstanceNorm::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitiveT failed";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_InstanceNorm;
}
if (this->primitive_->value.type != schema::PrimitiveType_InstanceNorm) {
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
return RET_ERROR;
}
if (this->primitive_->value.value == nullptr) {
auto attr = new (std::nothrow) schema::InstanceNormT();
if (attr == nullptr) {
MS_LOG(ERROR) << "new InstanceNormT failed";
delete this->primitive_;
return RET_ERROR;
}
attr->epsilon = GetValue<float>(prim.GetAttr("epsilon"));
this->primitive_->value.value = attr;
}
return RET_OK;
}
#else
int InstanceNorm::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto val_offset = schema::CreateInstanceNorm(*fbb);
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_InstanceNorm, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
float InstanceNorm::GetEpsilon() const { return this->primitive_->value_as_InstanceNorm()->epsilon(); }
#endif
} // namespace lite
} // namespace mindspore

@ -0,0 +1,45 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LITE_MINDSPORE_LITE_C_OPS_INSTANE_NORM_H_
#define LITE_MINDSPORE_LITE_C_OPS_INSTANE_NORM_H_
#include <vector>
#include <set>
#include <cmath>
#include "src/ops/primitive_c.h"
namespace mindspore {
namespace lite {
class InstanceNorm : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(InstanceNorm, PrimitiveC);
InstanceNorm() = default;
explicit InstanceNorm(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
void SetEpsilon(float epsilon);
#else
InstanceNorm() = default;
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
float GetEpsilon() const;
};
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_C_OPS_INSTANE_NORM_H_

@ -31,6 +31,7 @@
#include "src/ops/batch_to_space.h"
#include "src/ops/prior_box.h"
#include "src/ops/lstm.h"
#include "src/ops/instance_norm.h"
#include "src/ops/softmax.h"
#include "src/ops/activation.h"
#include "src/ops/deconv2d.h"
@ -140,6 +141,7 @@
#include "nnacl/matmul_parameter.h"
#include "nnacl/fp32/roi_pooling.h"
#include "nnacl/softmax_parameter.h"
#include "nnacl/instance_norm_parameter.h"
#include "nnacl/fp32/tile.h"
#include "nnacl/fp32/topk.h"
#include "nnacl/reduce_parameter.h"
@ -219,6 +221,22 @@ OpParameter *PopulateBatchNorm(const mindspore::lite::PrimitiveC *primitive) {
return reinterpret_cast<OpParameter *>(batch_norm_param);
}
OpParameter *PopulateInstanceNorm(const mindspore::lite::PrimitiveC *primitive) {
const auto param =
reinterpret_cast<mindspore::lite::InstanceNorm *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
InstanceNormParameter *instance_norm_param =
reinterpret_cast<InstanceNormParameter *>(malloc(sizeof(InstanceNormParameter)));
if (instance_norm_param == nullptr) {
MS_LOG(ERROR) << "malloc InstanceNormParameter failed.";
return nullptr;
}
memset(instance_norm_param, 0, sizeof(InstanceNormParameter));
instance_norm_param->op_parameter_.type_ = primitive->Type();
instance_norm_param->epsilon_ = param->GetEpsilon();
instance_norm_param->fused_ = false;
return reinterpret_cast<OpParameter *>(instance_norm_param);
}
OpParameter *PopulateFillParameter(const mindspore::lite::PrimitiveC *primitive) {
const auto param = reinterpret_cast<mindspore::lite::Fill *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
FillParameter *fill_param = reinterpret_cast<FillParameter *>(malloc(sizeof(FillParameter)));

@ -0,0 +1,93 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/arm/fp32/instance_norm.h"
#include "nnacl/fp32/instance_norm.h"
#include "src/kernel_registry.h"
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_InstanceNorm;
namespace mindspore::kernel {
int InstanceNormCPUKernel::Init() {
if (!InferShapeDone()) {
return RET_OK;
}
return ReSize();
}
int InstanceNormCPUKernel::ReSize() {
auto input_shapes = in_tensors_[0]->shape();
auto n_dim = input_shapes.size();
auto param = reinterpret_cast<InstanceNormParameter *>(op_parameter_);
param->batch_ = input_shapes[0];
param->channel_ = input_shapes[n_dim - 1];
param->unit_ = 1;
for (size_t i = 1; i < n_dim - 1; i++) {
param->unit_ *= input_shapes[i];
}
return RET_OK;
}
int InstanceNormCPUKernel::Run() {
auto ret = ParallelLaunch(this->context_->thread_pool_, InstanceNormRun, this, op_parameter_->thread_num_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "InstanceNormRun error error_code[" << ret << "]";
}
return ret;
}
int InstanceNormCPUKernel::DoExecute(int task_id) {
auto param = reinterpret_cast<InstanceNormParameter *>(op_parameter_);
InstanceNormFp32(in_tensors_.at(0)->MutableData(), in_tensors_.at(1)->MutableData(), in_tensors_.at(2)->MutableData(),
param, task_id, out_tensors_.at(0)->MutableData());
return mindspore::lite::RET_OK;
}
int InstanceNormRun(void *cdata, int task_id) {
auto kernel = reinterpret_cast<InstanceNormCPUKernel *>(cdata);
auto ret = kernel->DoExecute(task_id);
if (ret != RET_OK) {
MS_LOG(ERROR) << "InstanceNormRun error task_id[" << task_id << "] error_code[" << ret << "]";
}
return ret;
}
kernel::LiteKernel *CpuInstanceNormKernelCreator(const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter,
const lite::InnerContext *ctx, const kernel::KernelKey &desc,
const mindspore::lite::PrimitiveC *primitive) {
MS_ASSERT(opParameter != nullptr);
auto *kernel = new (std::nothrow) InstanceNormCPUKernel(opParameter, inputs, outputs, ctx, primitive);
if (kernel == nullptr) {
MS_LOG(ERROR) << "new InstanceNormCPUKernel fail!";
free(opParameter);
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
delete kernel;
return nullptr;
}
return kernel;
}
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_InstanceNorm, CpuInstanceNormKernelCreator)
} // namespace mindspore::kernel

@ -0,0 +1,46 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_INSTANCE_NORM_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_INSTANCE_NORM_H_
#include <vector>
#include "src/lite_kernel.h"
#include "include/context.h"
#include "nnacl/instance_norm_parameter.h"
#include "src/runtime/runtime_api.h"
using mindspore::lite::InnerContext;
namespace mindspore::kernel {
class InstanceNormCPUKernel : public LiteKernel {
public:
InstanceNormCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const InnerContext *ctx,
const mindspore::lite::PrimitiveC *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {}
~InstanceNormCPUKernel() override = default;
int Init() override;
int ReSize() override;
int Run() override;
virtual int DoExecute(int task_id);
};
int InstanceNormRun(void *cdata, int task_id);
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_INSTANCE_NORM_H_

@ -0,0 +1,134 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <iostream>
#include "src/common/log_adapter.h"
#include "common/common_test.h"
#include "mindspore/lite/nnacl/fp32/instance_norm.h"
#include "mindspore/lite/src/kernel_registry.h"
#include "mindspore/lite/src/lite_kernel.h"
namespace mindspore {
class TestInstanceNormFp32 : public mindspore::CommonTest {
public:
TestInstanceNormFp32() {}
};
TEST_F(TestInstanceNormFp32, INTest1) {
std::vector<float> in_data = {-11.18675, 11.433986, 11.386012, 11.245945, -2.7614849, 14.692399,
-1.1983503, -6.6790967, 6.383416, -13.3213005, -8.693595, 9.476344};
std::vector<float> in_data1 = {12.352293, 5.122387, 14.249514};
std::vector<float> in_data2 = {14.632595, 0.70900035, 11.179003};
InstanceNormParameter op_param;
op_param.op_parameter_.type_ = schema::PrimitiveType_InstanceNorm;
op_param.epsilon_ = 0.001f;
lite::Tensor input0_tensor(kNumberTypeFloat32, {1, 2, 2, 3});
lite::Tensor input1_tensor(kNumberTypeFloat32, {3});
lite::Tensor input2_tensor(kNumberTypeFloat32, {3});
input0_tensor.SetData(in_data.data());
input1_tensor.SetData(in_data1.data());
input2_tensor.SetData(in_data2.data());
std::vector<lite::Tensor *> inputs_tensor = {&input0_tensor, &input1_tensor, &input2_tensor};
std::vector<float> output(12);
std::vector<float> corr_out = {-6.1533737, 7.4904885, -0.8563998, -0.289212, -9.356432, 0.13245535,
-3.5422924, -14.005781, -2.3525476, -6.7113695, -16.396551, -1.4275324};
lite::Tensor output0_tensor(kNumberTypeFloat32, {1, 2, 2, 3});
output0_tensor.SetData(output.data());
std::vector<lite::Tensor *> outputs_tensor = {&output0_tensor};
kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, schema::PrimitiveType_InstanceNorm};
auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc);
ASSERT_NE(creator, nullptr);
lite::InnerContext ctx;
ctx.thread_num_ = 4;
ASSERT_EQ(lite::RET_OK, ctx.Init());
kernel::LiteKernel *kernel =
creator(inputs_tensor, outputs_tensor, reinterpret_cast<OpParameter *>(&op_param), &ctx, desc, nullptr);
ASSERT_NE(kernel, nullptr);
auto output_tensor_shape = output0_tensor.shape();
kernel->Run();
printf("==================output data=================\n");
for (int i = 0; i < output0_tensor.ElementsNum(); i++) {
std::cout << output[i] << " ,";
}
std::cout << std::endl;
CompareOutputData(output.data(), corr_out.data(), output0_tensor.ElementsNum(), 0.001);
input0_tensor.SetData(nullptr);
input1_tensor.SetData(nullptr);
input2_tensor.SetData(nullptr);
output0_tensor.SetData(nullptr);
}
TEST_F(TestInstanceNormFp32, INTest2) {
std::vector<float> in_data = {-11.18675, 11.433986, 11.386012, 11.245945, -2.7614849, 14.692399,
-1.1983503, -6.6790967, 6.383416, -13.3213005, -8.693595, 9.476344,
-11.18675, 11.433986, 11.386012, 11.245945, -2.7614849, 14.692399,
-1.1983503, -6.6790967, 6.383416, -13.3213005, -8.693595, 9.476344};
std::vector<float> in_data1 = {12.352293, 5.122387, 14.249514, 12.352293, 5.122387, 14.249514};
std::vector<float> in_data2 = {14.632595, 0.70900035, 11.179003, 14.632595, 0.70900035, 11.179003};
InstanceNormParameter op_param;
op_param.op_parameter_.type_ = schema::PrimitiveType_InstanceNorm;
op_param.epsilon_ = 0.001f;
lite::Tensor input0_tensor(kNumberTypeFloat32, {2, 2, 2, 3});
lite::Tensor input1_tensor(kNumberTypeFloat32, {6});
lite::Tensor input2_tensor(kNumberTypeFloat32, {6});
input0_tensor.SetData(in_data.data());
input1_tensor.SetData(in_data1.data());
input2_tensor.SetData(in_data2.data());
std::vector<lite::Tensor *> inputs_tensor = {&input0_tensor, &input1_tensor, &input2_tensor};
std::vector<float> output(24);
std::vector<float> corr_out = {-6.1533737, 7.4904885, -0.8563998, -0.289212, -9.356432, 0.13245535,
-3.5422924, -14.005781, -2.3525476, -6.7113695, -16.396551, -1.4275324,
-6.1533737, 7.4904885, -0.8563998, -0.289212, -9.356432, 0.13245535,
-3.5422924, -14.005781, -2.3525476, -6.7113695, -16.396551, -1.4275324};
lite::Tensor output0_tensor(kNumberTypeFloat32, {2, 2, 2, 3});
output0_tensor.SetData(output.data());
std::vector<lite::Tensor *> outputs_tensor = {&output0_tensor};
kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, schema::PrimitiveType_InstanceNorm};
auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc);
ASSERT_NE(creator, nullptr);
lite::InnerContext ctx;
ctx.thread_num_ = 4;
ASSERT_EQ(lite::RET_OK, ctx.Init());
kernel::LiteKernel *kernel =
creator(inputs_tensor, outputs_tensor, reinterpret_cast<OpParameter *>(&op_param), &ctx, desc, nullptr);
ASSERT_NE(kernel, nullptr);
auto output_tensor_shape = output0_tensor.shape();
kernel->Run();
printf("==================output data=================\n");
for (int i = 0; i < output0_tensor.ElementsNum(); i++) {
std::cout << output[i] << " ,";
}
std::cout << std::endl;
CompareOutputData(output.data(), corr_out.data(), output0_tensor.ElementsNum(), 0.001);
input0_tensor.SetData(nullptr);
input1_tensor.SetData(nullptr);
input2_tensor.SetData(nullptr);
output0_tensor.SetData(nullptr);
}
} // namespace mindspore
Loading…
Cancel
Save