|
|
|
@ -37,7 +37,7 @@ int ActivationFp16CPUKernel::Init() {
|
|
|
|
|
if (type_ != schema::ActivationType_RELU && type_ != schema::ActivationType_RELU6 &&
|
|
|
|
|
type_ != schema::ActivationType_LEAKY_RELU && type_ != schema::ActivationType_SIGMOID &&
|
|
|
|
|
type_ != schema::ActivationType_TANH && type_ != schema::ActivationType_HSWISH &&
|
|
|
|
|
type_ != schema::ActivationType_SWISH) {
|
|
|
|
|
type_ != schema::ActivationType_SWISH && type_ != schema::ActivationType_HARD_TANH) {
|
|
|
|
|
MS_LOG(ERROR) << "Activation fp16 not support type: " << type_;
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
@ -67,6 +67,9 @@ int ActivationFp16CPUKernel::DoActivation(int task_id) {
|
|
|
|
|
error_code = HSwishFp16(fp16_input_ + stride * task_id, fp16_output_ + stride * task_id, count);
|
|
|
|
|
} else if (type_ == schema::ActivationType_SWISH) {
|
|
|
|
|
error_code = SwishFp16(fp16_input_ + stride * task_id, fp16_output_ + stride * task_id, count);
|
|
|
|
|
} else if (type_ == schema::ActivationType_HARD_TANH) {
|
|
|
|
|
error_code =
|
|
|
|
|
HardTanhFp16(fp16_input_ + stride * task_id, count, fp16_output_ + stride * task_id, min_val_, max_val_);
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(ERROR) << "Activation fp16 not support type: " << type_;
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|