|
|
|
@ -41,31 +41,18 @@ ops::PrimitiveC *TFActivationParser::Parse(const tensorflow::NodeDef &tf_op,
|
|
|
|
|
prim->set_activation_type(mindspore::ActivationType::SELU);
|
|
|
|
|
} else if (tf_op.op() == "Softplus") {
|
|
|
|
|
prim->set_activation_type(mindspore::ActivationType::SOFTPLUS);
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(ERROR) << "unsupported activation type:" << tf_op.op();
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
*output_size = 1;
|
|
|
|
|
if (AddOpInput(tf_op, 0, inputs) != RET_OK) {
|
|
|
|
|
MS_LOG(ERROR) << "add op input failed";
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return prim.release();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ops::PrimitiveC *TFLeakyReluParser::Parse(const tensorflow::NodeDef &tf_op,
|
|
|
|
|
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
|
|
|
|
|
std::vector<std::string> *inputs, int *output_size) {
|
|
|
|
|
auto prim = std::make_unique<ops::LeakyRelu>();
|
|
|
|
|
|
|
|
|
|
} else if (tf_op.op() == "LeakyRelu") {
|
|
|
|
|
prim->set_activation_type(mindspore::ActivationType::LEAKY_RELU);
|
|
|
|
|
tensorflow::AttrValue attr_value;
|
|
|
|
|
if (!TensorFlowUtils::FindAttrValue(tf_op, "alpha", &attr_value)) {
|
|
|
|
|
MS_LOG(ERROR) << "The attribute alpha should be specified.";
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
prim->set_negative_slope(attr_value.f());
|
|
|
|
|
prim->set_alpha(attr_value.f());
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(ERROR) << "unsupported activation type:" << tf_op.op();
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
*output_size = 1;
|
|
|
|
|
if (AddOpInput(tf_op, 0, inputs) != RET_OK) {
|
|
|
|
@ -81,7 +68,7 @@ TFNodeRegistrar g_tfRelu6Parser("Relu6", new TFActivationParser());
|
|
|
|
|
TFNodeRegistrar g_tfSigmoidParser("Sigmoid", new TFActivationParser());
|
|
|
|
|
TFNodeRegistrar g_tfTanhParser("Tanh", new TFActivationParser());
|
|
|
|
|
TFNodeRegistrar g_tfSeLUParser("Selu", new TFActivationParser());
|
|
|
|
|
TFNodeRegistrar g_tfLeakyReluParser("LeakyRelu", new TFLeakyReluParser());
|
|
|
|
|
TFNodeRegistrar g_tfLeakyReluParser("LeakyRelu", new TFActivationParser());
|
|
|
|
|
TFNodeRegistrar g_tfSoftplusParser("Softplus", new TFActivationParser());
|
|
|
|
|
} // namespace lite
|
|
|
|
|
} // namespace mindspore
|
|
|
|
|