|
|
|
@ -86,23 +86,27 @@ STATUS OnnxPReluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::No
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const onnx::TensorProto *slope = ¶ms[0];
|
|
|
|
|
if (slope == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "input error: params[0] is null";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
const auto slope_raw_data = reinterpret_cast<const float *>(slope->raw_data().data());
|
|
|
|
|
const int64_t slope_size = slope->raw_data().size() / sizeof(float);
|
|
|
|
|
if (slope_size == 1) {
|
|
|
|
|
attr->slope.push_back(*slope_raw_data);
|
|
|
|
|
attr->channelShared = true;
|
|
|
|
|
} else {
|
|
|
|
|
attr->slope.resize(slope_size);
|
|
|
|
|
attr->channelShared = false;
|
|
|
|
|
if (memcpy_s(attr->slope.data(), slope_size * sizeof(float), slope_raw_data, slope_size * sizeof(float)) != 0) {
|
|
|
|
|
MS_LOG(ERROR) << "memcpy_s failed";
|
|
|
|
|
if (!params.empty()) {
|
|
|
|
|
const onnx::TensorProto *slope = ¶ms[0];
|
|
|
|
|
if (slope == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "input error: params[0] is null";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
const auto slope_raw_data = reinterpret_cast<const float *>(slope->raw_data().data());
|
|
|
|
|
const int64_t slope_size = slope->raw_data().size() / sizeof(float);
|
|
|
|
|
if (slope_size == 1) {
|
|
|
|
|
attr->slope.push_back(*slope_raw_data);
|
|
|
|
|
attr->channelShared = true;
|
|
|
|
|
} else {
|
|
|
|
|
attr->slope.resize(slope_size);
|
|
|
|
|
attr->channelShared = false;
|
|
|
|
|
if (memcpy_s(attr->slope.data(), slope_size * sizeof(float), slope_raw_data, slope_size * sizeof(float)) != 0) {
|
|
|
|
|
MS_LOG(ERROR) << "memcpy_s failed";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(WARNING) << "The slope pf prelu is null, which may cause errors.";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
op->primitive->value.type = schema::PrimitiveType_PReLU;
|
|
|
|
|