fix floor_div_parser and random_standard_norm bug

pull/13152/head
zhaodezan 4 years ago
parent 2a2c7e7399
commit ae64426d6c

@ -488,6 +488,10 @@ schema::PrimitiveT *RangePrimitiveCreator(const AnfNodePtr &node) {
auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::Range>>(node);
return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr;
}
schema::PrimitiveT *RandomStandardNormalPrimitiveCreator(const AnfNodePtr &node) {
auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::RandomStandardNormal>>(node);
return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr;
}
schema::PrimitiveT *RankPrimitiveCreator(const AnfNodePtr &node) {
auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::Rank>>(node);
return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr;
@ -843,6 +847,8 @@ RegistryMSOps g_partialFusionPrimitiveCreatorRegistry("PartialFusion", PartialFu
RegistryMSOps g_powerGradPrimitiveCreatorRegistry("PowerGrad", PowerGradPrimitiveCreator);
RegistryMSOps g_powFusionPrimitiveCreatorRegistry("PowFusion", PowFusionPrimitiveCreator);
RegistryMSOps g_pReLUFusionPrimitiveCreatorRegistry("PReLUFusion", PReLUFusionPrimitiveCreator);
RegistryMSOps g_RandomStandardNormalPrimitiveCreatorRegistry("RandomStandardNormal",
RandomStandardNormalPrimitiveCreator);
RegistryMSOps g_rangePrimitiveCreatorRegistry("Range", RangePrimitiveCreator);
RegistryMSOps g_rankPrimitiveCreatorRegistry("Rank", RankPrimitiveCreator);
RegistryMSOps g_reciprocalPrimitiveCreatorRegistry("Reciprocal", ReciprocalPrimitiveCreator);

@ -37,6 +37,7 @@
#include "ops/ceil.h"
#include "ops/fusion/exp_fusion.h"
#include "ops/floor.h"
#include "ops/floor_div.h"
#include "ops/floor_mod.h"
#include "ops/log.h"
#include "ops/sqrt.h"
@ -299,6 +300,20 @@ ops::PrimitiveC *TFFloorParser::Parse(const tensorflow::NodeDef &tf_op,
return prim.release();
}
ops::PrimitiveC *TFFloorDivParser::Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
std::vector<std::string> *inputs, int *output_size) {
auto prim = std::make_unique<ops::FloorDiv>();
*output_size = 1;
if (AddOpInput(tf_op, 0, inputs) != RET_OK || AddOpInput(tf_op, 1, inputs) != RET_OK) {
MS_LOG(ERROR) << "add op input failed";
return nullptr;
}
return prim.release();
}
ops::PrimitiveC *TFFloorModParser::Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
std::vector<std::string> *inputs, int *output_size) {
@ -435,6 +450,7 @@ TFNodeRegistrar g_tfSquareParser("Square", new TFSquareParser());
TFNodeRegistrar g_tfCeilParser("Ceil", new TFCeilParser());
TFNodeRegistrar g_tfExpParser("Exp", new TFExpParser());
TFNodeRegistrar g_tfFloorParser("Floor", new TFFloorParser());
TFNodeRegistrar g_tfFloorDivParser("FloorDiv", new TFFloorDivParser());
TFNodeRegistrar g_tfFloorModParser("FloorMod", new TFFloorModParser());
TFNodeRegistrar g_tfLogParser("Log", new TFLogParser());
TFNodeRegistrar g_tfSqrtParser("Sqrt", new TFSqrtParser());

@ -204,6 +204,16 @@ class TFFloorParser : public TFNodeParser {
std::vector<std::string> *inputs, int *output_size) override;
};
class TFFloorDivParser : public TFNodeParser {
public:
TFFloorDivParser() = default;
~TFFloorDivParser() override = default;
ops::PrimitiveC *Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
std::vector<std::string> *inputs, int *output_size) override;
};
class TFFloorModParser : public TFNodeParser {
public:
TFFloorModParser() = default;

Loading…
Cancel
Save