|
|
|
|
@ -29,6 +29,7 @@ using mindspore::schema::ActivationType_HSWISH;
|
|
|
|
|
using mindspore::schema::ActivationType_LEAKY_RELU;
|
|
|
|
|
using mindspore::schema::ActivationType_RELU;
|
|
|
|
|
using mindspore::schema::ActivationType_RELU6;
|
|
|
|
|
using mindspore::schema::ActivationType_SWISH;
|
|
|
|
|
using mindspore::schema::PrimitiveType_Activation;
|
|
|
|
|
|
|
|
|
|
namespace mindspore::kernel {
|
|
|
|
|
@ -44,32 +45,34 @@ int ActivationCPUKernel::DoActivation(int task_id) {
|
|
|
|
|
int stride = UP_DIV(length, thread_count_);
|
|
|
|
|
int count = MSMIN(stride, length - stride * task_id);
|
|
|
|
|
|
|
|
|
|
auto error_code = RET_OK;
|
|
|
|
|
auto ret = RET_OK;
|
|
|
|
|
|
|
|
|
|
if (type_ == schema::ActivationType_RELU) {
|
|
|
|
|
error_code = Fp32Relu(input_addr + stride * task_id, count, output_addr + stride * task_id);
|
|
|
|
|
ret = Fp32Relu(input_addr + stride * task_id, count, output_addr + stride * task_id);
|
|
|
|
|
} else if (type_ == schema::ActivationType_RELU6) {
|
|
|
|
|
error_code = Fp32Relu6(input_addr + stride * task_id, count, output_addr + stride * task_id);
|
|
|
|
|
ret = Fp32Relu6(input_addr + stride * task_id, count, output_addr + stride * task_id);
|
|
|
|
|
} else if (type_ == schema::ActivationType_LEAKY_RELU) {
|
|
|
|
|
error_code = LRelu(input_addr + stride * task_id, count, output_addr + stride * task_id, alpha_);
|
|
|
|
|
ret = LRelu(input_addr + stride * task_id, count, output_addr + stride * task_id, alpha_);
|
|
|
|
|
} else if (type_ == schema::ActivationType_SIGMOID) {
|
|
|
|
|
error_code = Sigmoid(input_addr + stride * task_id, count, output_addr + stride * task_id);
|
|
|
|
|
ret = Sigmoid(input_addr + stride * task_id, count, output_addr + stride * task_id);
|
|
|
|
|
} else if (type_ == schema::ActivationType_TANH) {
|
|
|
|
|
error_code = Tanh(input_addr + stride * task_id, count, output_addr + stride * task_id);
|
|
|
|
|
ret = Tanh(input_addr + stride * task_id, count, output_addr + stride * task_id);
|
|
|
|
|
} else if (type_ == schema::ActivationType_SWISH) {
|
|
|
|
|
ret = Swish(input_addr + stride * task_id, count, output_addr + stride * task_id);
|
|
|
|
|
} else if (type_ == schema::ActivationType_HSWISH) {
|
|
|
|
|
error_code = HSwish(input_addr + stride * task_id, count, output_addr + stride * task_id);
|
|
|
|
|
ret = HSwish(input_addr + stride * task_id, count, output_addr + stride * task_id);
|
|
|
|
|
} else if (type_ == schema::ActivationType_HSIGMOID) {
|
|
|
|
|
error_code = HSigmoid(input_addr + stride * task_id, count, output_addr + stride * task_id);
|
|
|
|
|
ret = HSigmoid(input_addr + stride * task_id, count, output_addr + stride * task_id);
|
|
|
|
|
} else if (type_ == schema::ActivationType_HARD_TANH) {
|
|
|
|
|
error_code = HardTanh(input_addr + stride * task_id, count, output_addr + stride * task_id, min_val_, max_val_);
|
|
|
|
|
ret = HardTanh(input_addr + stride * task_id, count, output_addr + stride * task_id, min_val_, max_val_);
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(ERROR) << "Activation type error";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
if (error_code != RET_OK) {
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
if (ret != RET_OK) {
|
|
|
|
|
MS_LOG(ERROR) << "Activation error, ret: " << ret;
|
|
|
|
|
}
|
|
|
|
|
return RET_OK;
|
|
|
|
|
return ret;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int ActivationRun(void *cdata, int task_id) {
|
|
|
|
|
|