!13354 [MSLITE][Develop] fix bug of cpu fp32 op: layernorm; add gelu fp32

From: @yangruoqi713
Reviewed-by: @zhang_xue_tong,@zhanghaibo5
Signed-off-by: @zhang_xue_tong
pull/13354/MERGE
mindspore-ci-bot, 4 years ago, committed by Gitee
commit c03f962cb6

@@ -248,3 +248,11 @@ int HardTanh(const float *src, int length, float *dst, float min_val, float max_
   }
   return NNACL_OK;
 }
+
+int Gelu(const float *src, int length, float *dst) {
+  for (int i = 0; i < length; ++i) {
+    float tanh_res = TanhOpt(sqrt(2 / M_PI) * (src[i] + 0.044715 * pow(src[i], 3)));
+    dst[i] = 0.5f * src[i] * (1 + tanh_res);
+  }
+  return NNACL_OK;
+}
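Note: the new Gelu kernel uses the common tanh approximation GELU(x) ≈ 0.5 * x * (1 + tanh(sqrt(2/PI) * (x + 0.044715 * x^3))). A minimal standalone sketch of the same approximation for reference, independent of nnacl (TanhOpt is replaced with std::tanh here; names are illustrative):

#include <cmath>
#include <cstdio>

// Standalone tanh-approximation of GELU, mirroring the kernel above.
static void GeluRef(const float *src, int length, float *dst) {
  for (int i = 0; i < length; ++i) {
    float x = src[i];
    float inner = std::sqrt(2.0f / 3.14159265358979323846f) * (x + 0.044715f * x * x * x);
    dst[i] = 0.5f * x * (1.0f + std::tanh(inner));
  }
}

int main() {
  float in[5] = {-2.0f, -1.0f, 0.0f, 1.0f, 2.0f};
  float out[5];
  GeluRef(in, 5, out);
  for (int i = 0; i < 5; ++i) {
    std::printf("gelu(%.1f) = %f\n", in[i], out[i]);
  }
  return 0;
}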

@@ -40,6 +40,7 @@ int HSigmoid(const float *src, int length, float *dst);
 int Swish(const float *src, int length, float *dst);
 int HSwish(const float *src, int length, float *dst);
 int HardTanh(const float *src, int length, float *dst, float min_val, float max_val);
+int Gelu(const float *src, int length, float *dst);
 float TanhOpt(float src);
 #ifdef __cplusplus

@@ -64,9 +64,10 @@ void LayerNormGammaAndBeta(float *dst, const float *src, const float *gamma_data
   }
 }
-int LayerNorm(const float *src_data, const float *gamma_data, const float *beta_data, float *dst_data,
-              LayerNormParameter *param, float *out_mean, float *out_deno, size_t task_id) {
-  if (src_data == NULL || dst_data == NULL || gamma_data == NULL || beta_data == NULL) {
+int LayerNorm(const float *src_data, const float *gamma_data, const float *beta_data, float *dst_data, float *out_mean,
+              float *out_deno, LayerNormParameter *param, size_t task_id) {
+  if (src_data == NULL || dst_data == NULL || gamma_data == NULL || beta_data == NULL || out_mean == NULL ||
+      out_deno == NULL) {
     return NNACL_NULL_PTR;
   }
   int step = UP_DIV(param->norm_outer_size_, param->op_parameter_.thread_num_);
@@ -74,25 +75,22 @@ int LayerNorm(const float *src_data, const float *gamma_data, const float *beta_
   for (int i = task_id * step; i < thread_end; i++) {
     const float *src_norm = src_data + i * param->norm_inner_size_;
     float *dst_norm = dst_data + i * param->norm_inner_size_;
-    float mean = 0.0f;
-    float square_mean = 0.0f;
-    LayerNormMeanAndSquare(src_norm, param->norm_inner_size_, &mean, &square_mean);
-    const float deno = 1 / sqrtf(square_mean - mean * mean + param->epsilon_);
-    if ((out_mean != NULL) && (out_deno != NULL)) {
-      out_mean[i] = mean;
-      out_deno[i] = deno;
-    }
+    out_mean[i] = 0.0f;
+    out_deno[i] = 0.0f;
+    LayerNormMeanAndSquare(src_norm, param->norm_inner_size_, &out_mean[i], &out_deno[i]);
+    const float deno = 1 / sqrtf(out_deno[i] - out_mean[i] * out_mean[i] + param->epsilon_);
     if (param->norm_outer_size_ <= param->params_outer_size_) {
       for (int x = 0; x < param->norm_inner_size_ / param->params_inner_size_; x++) {
         const float *src_param = src_norm + x * param->params_inner_size_;
         float *dst_param = dst_norm + x * param->params_inner_size_;
-        LayerNormGammaAndBeta(dst_param, src_param, gamma_data, beta_data, param->params_inner_size_, mean, deno);
+        LayerNormGammaAndBeta(dst_param, src_param, gamma_data, beta_data, param->params_inner_size_, out_mean[i],
+                              deno);
       }
     } else {
       int x = i / param->params_outer_size_;
       const float *gamma = gamma_data + x * param->norm_inner_size_;
       const float *beta = beta_data + x * param->norm_inner_size_;
-      LayerNormGammaAndBeta(dst_norm, src_norm, gamma, beta, param->norm_inner_size_, mean, deno);
+      LayerNormGammaAndBeta(dst_norm, src_norm, gamma, beta, param->norm_inner_size_, out_mean[i], deno);
     }
   }
   return NNACL_OK;
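Note: the fix writes the per-row mean and mean of squares directly into out_mean[i] and out_deno[i] (which are now required to be non-NULL), instead of optionally copying local temporaries, and the denominator stays deno = 1 / sqrtf(E[x^2] - mean^2 + epsilon). A minimal standalone sketch of the same per-row computation, independent of nnacl (illustrative names):

#include <cmath>
#include <cstddef>

// Reference layer normalization over one row of n floats; gamma/beta have length n.
// Mirrors the formula above: deno = 1 / sqrt(E[x^2] - mean^2 + epsilon).
static void LayerNormRowRef(const float *src, const float *gamma, const float *beta, float *dst, size_t n,
                            float epsilon) {
  float mean = 0.0f;
  float square_mean = 0.0f;
  for (size_t i = 0; i < n; ++i) {
    mean += src[i];
    square_mean += src[i] * src[i];
  }
  mean /= static_cast<float>(n);
  square_mean /= static_cast<float>(n);
  const float deno = 1.0f / std::sqrt(square_mean - mean * mean + epsilon);
  for (size_t i = 0; i < n; ++i) {
    dst[i] = (src[i] - mean) * deno * gamma[i] + beta[i];
  }
}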

@@ -23,8 +23,8 @@
 extern "C" {
 #endif
-int LayerNorm(const float *src_data, const float *gamma_data, const float *beta_data, float *dst_data,
-              LayerNormParameter *param, float *out_mean, float *out_deno, size_t task_id);
+int LayerNorm(const float *src_data, const float *gamma_data, const float *beta_data, float *dst_data, float *out_mean,
+              float *out_deno, LayerNormParameter *param, size_t task_id);
 #ifdef __cplusplus
 }
 #endif

@@ -0,0 +1,37 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "src/ops/populate/populate_register.h"
+#include "nnacl/fp32_grad/softmax_grad.h"
+
+namespace mindspore {
+namespace lite {
+OpParameter *PopulateSparseSoftmaxCrossEntropyWithLogitsParameter(const void *prim) {
+  SoftmaxCrossEntropyParameter *softmax_cross_entropy_param_ =
+    reinterpret_cast<SoftmaxCrossEntropyParameter *>(malloc(sizeof(SoftmaxCrossEntropyParameter)));
+  if (softmax_cross_entropy_param_ == nullptr) {
+    MS_LOG(ERROR) << "malloc SoftmaxCrossEntropyParameter failed.";
+    return nullptr;
+  }
+  memset(softmax_cross_entropy_param_, 0, sizeof(SoftmaxCrossEntropyParameter));
+  auto primitive = static_cast<const schema::Primitive *>(prim);
+  softmax_cross_entropy_param_->op_parameter_.type_ = primitive->value_type();
+  return reinterpret_cast<OpParameter *>(softmax_cross_entropy_param_);
+}
+Registry SparseSoftmaxCrossEntropyWithLogitsParameterRegistry(schema::PrimitiveType_SparseSoftmaxCrossEntropyWithLogits,
+                                                              PopulateSparseSoftmaxCrossEntropyWithLogitsParameter,
+                                                              SCHEMA_CUR);
+} // namespace lite
+} // namespace mindspore
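Note: the file-scope Registry object above uses the usual self-registration idiom: its constructor records the populate callback for the primitive type at static-initialization time, so the runtime can later look it up by type. A generic sketch of that idiom follows; the class and function names here are illustrative placeholders, not the actual types from populate_register.h:

#include <functional>
#include <map>
#include <utility>

// Illustrative parameter base and populate-callback signature.
struct OpParameter { int type_; };
using ParameterCreator = std::function<OpParameter *(const void *prim)>;

// Illustrative singleton map from primitive type to populate callback.
class PopulateRegistry {
 public:
  static PopulateRegistry &Instance() {
    static PopulateRegistry inst;
    return inst;
  }
  void Register(int prim_type, ParameterCreator creator) { creators_[prim_type] = std::move(creator); }
  ParameterCreator Get(int prim_type) const {
    auto it = creators_.find(prim_type);
    if (it == creators_.end()) {
      return nullptr;
    }
    return it->second;
  }

 private:
  std::map<int, ParameterCreator> creators_;
};

// Constructing a file-scope object performs the registration at startup,
// e.g. Registry g_someOpReg(kSomePrimitiveType, SomePopulateFn);
struct Registry {
  Registry(int prim_type, ParameterCreator creator) {
    PopulateRegistry::Instance().Register(prim_type, std::move(creator));
  }
};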

@@ -38,7 +38,7 @@ int ActivationCPUKernel::Init() {
       type_ != schema::ActivationType_LEAKY_RELU && type_ != schema::ActivationType_SIGMOID &&
       type_ != schema::ActivationType_TANH && type_ != schema::ActivationType_HSWISH &&
       type_ != schema::ActivationType_SWISH && type_ != schema::ActivationType_HSIGMOID &&
-      type_ != schema::ActivationType_HARD_TANH) {
+      type_ != schema::ActivationType_HARD_TANH && type_ != schema::ActivationType_GELU) {
     MS_LOG(ERROR) << "Activation fp32 not support type: " << type_;
     return RET_ERROR;
   }
@@ -78,6 +78,8 @@ int ActivationCPUKernel::DoActivation(int task_id) {
     ret = HSigmoid(input_addr + stride * task_id, count, output_addr + stride * task_id);
   } else if (type_ == schema::ActivationType_HARD_TANH) {
     ret = HardTanh(input_addr + stride * task_id, count, output_addr + stride * task_id, min_val_, max_val_);
+  } else if (type_ == schema::ActivationType_GELU) {
+    ret = Gelu(input_addr + stride * task_id, count, output_addr + stride * task_id);
   } else {
     MS_LOG(ERROR) << "Activation type error";
     return RET_ERROR;

@@ -61,7 +61,7 @@ int LayerNormCPUKernel::ReSize() {
 }

 int LayerNormCPUKernel::DoLayerNorm(int thread_id) {
-  int ret = LayerNorm(src_data_, gamma_data_, beta_data_, dst_data_, param_, mean_data_, var_data_, thread_id);
+  int ret = LayerNorm(src_data_, gamma_data_, beta_data_, dst_data_, mean_data_, var_data_, param_, thread_id);
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "DoLayerNorm error error_code[" << ret << "]";
     return ret;
@@ -85,11 +85,18 @@ int LayerNormCPUKernel::Run() {
   gamma_data_ = reinterpret_cast<float *>(in_tensors_.at(1)->data_c());
   beta_data_ = reinterpret_cast<float *>(in_tensors_.at(2)->data_c());
   dst_data_ = reinterpret_cast<float *>(out_tensors_.at(0)->data_c());
-  if (out_tensors_.size() >= 3) {
+  if (out_tensors_.size() == 3) {
     mean_data_ = reinterpret_cast<float *>(out_tensors_.at(1)->data_c());
     var_data_ = reinterpret_cast<float *>(out_tensors_.at(2)->data_c());
+  } else {
+    mean_data_ = reinterpret_cast<float *>(context_->allocator->Malloc(param_->norm_outer_size_ * sizeof(float)));
+    var_data_ = reinterpret_cast<float *>(context_->allocator->Malloc(param_->norm_outer_size_ * sizeof(float)));
   }
   ret = ParallelLaunch(this->context_->thread_pool_, LayerNormRun, this, op_parameter_->thread_num_);
+  if (out_tensors_.size() != 3) {
+    context_->allocator->Free(mean_data_);
+    context_->allocator->Free(var_data_);
+  }
   return ret;
 }
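Note: the Run() change always hands real mean/variance buffers to the C kernel: when the op exposes only one output tensor, scratch buffers of norm_outer_size_ floats are borrowed from the context allocator and released after the parallel launch. A generic sketch of that fallback-workspace pattern follows; the allocator and callback types here are illustrative stand-ins, not the lite runtime API:

#include <cstddef>
#include <cstdlib>
#include <functional>

// Illustrative stand-in for an allocator that hands out scratch memory.
struct ScratchAllocator {
  void *Malloc(size_t size) { return std::malloc(size); }
  void Free(void *ptr) { std::free(ptr); }
};

// Run `body` with mean/var buffers: use the caller's buffers when provided,
// otherwise borrow scratch memory and release it afterwards.
int RunWithMeanVar(ScratchAllocator *alloc, size_t rows, float *mean_out, float *var_out,
                   const std::function<int(float *, float *)> &body) {
  bool borrowed = (mean_out == nullptr || var_out == nullptr);
  if (borrowed) {
    mean_out = static_cast<float *>(alloc->Malloc(rows * sizeof(float)));
    var_out = static_cast<float *>(alloc->Malloc(rows * sizeof(float)));
  }
  int ret = body(mean_out, var_out);
  if (borrowed) {
    alloc->Free(mean_out);
    alloc->Free(var_out);
  }
  return ret;
}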
