adjust mindir model attr changes

pull/8686/head
xuanyue 4 years ago
parent 1783a3f6e2
commit c1ce164e42

@ -51,7 +51,7 @@ int BiasAdd::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &in
MS_LOG(INFO) << "BiasAdd's attr axis is set to default"; MS_LOG(INFO) << "BiasAdd's attr axis is set to default";
attr->axis = {1}; attr->axis = {1};
} else { } else {
attr->axis = GetValue<std::vector<int>>(prim.GetAttr("axis")); attr->axis = CastToInt(prim.GetAttr("axis"), true);
} }
this->primitive_->value.value = attr; this->primitive_->value.value = attr;
if (this->primitive_->value.value == nullptr) { if (this->primitive_->value.value == nullptr) {

@ -49,7 +49,7 @@ int BiasGrad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &i
MS_LOG(WARNING) << "get axis failed"; MS_LOG(WARNING) << "get axis failed";
attr->axis = {0}; attr->axis = {0};
} else { } else {
attr->axis = GetValue<std::vector<int>>(prim.GetAttr("axis")); attr->axis = CastToInt(prim.GetAttr("axis"), true);
} }
this->primitive_->value.value = attr; this->primitive_->value.value = attr;
if (this->primitive_->value.value == nullptr) { if (this->primitive_->value.value == nullptr) {

@ -51,7 +51,7 @@ int Concat::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inp
MS_LOG(ERROR) << "new primitiveT value failed"; MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR; return RET_ERROR;
} }
auto prim_axis = GetValue<int>(prim.GetAttr("axis")); auto prim_axis = CastToInt(prim.GetAttr("axis"), false).front();
attr->axis = prim_axis; attr->axis = prim_axis;
this->primitive_->value.value = attr; this->primitive_->value.value = attr;
if (this->primitive_->value.value == nullptr) { if (this->primitive_->value.value == nullptr) {

@ -139,21 +139,21 @@ void Conv2D::PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT
} else { } else {
attr->format = schema::Format::Format_NUM_OF_FORMAT; attr->format = schema::Format::Format_NUM_OF_FORMAT;
} }
auto pad_list = GetValue<std::vector<int>>(prim.GetAttr("pad_list")); auto pad_list = CastToInt(prim.GetAttr("pad_list"), true);
attr->padUp = pad_list[0]; attr->padUp = pad_list[0];
attr->padDown = pad_list[1]; attr->padDown = pad_list[1];
attr->padLeft = pad_list[2]; attr->padLeft = pad_list[2];
attr->padRight = pad_list[3]; attr->padRight = pad_list[3];
auto dilation = GetValue<std::vector<int>>(prim.GetAttr("dilation")); auto dilation = CastToInt(prim.GetAttr("dilation"), true);
attr->dilateH = dilation[0]; attr->dilateH = dilation[0];
attr->dilateW = dilation[1]; attr->dilateW = dilation[1];
auto kernel_size = GetValue<std::vector<int>>(prim.GetAttr("kernel_size")); auto kernel_size = CastToInt(prim.GetAttr("kernel_size"), true);
attr->kernelH = kernel_size[0]; attr->kernelH = kernel_size[0];
attr->kernelW = kernel_size[1]; attr->kernelW = kernel_size[1];
auto stride = GetValue<std::vector<int>>(prim.GetAttr("stride")); auto stride = CastToInt(prim.GetAttr("stride"), true);
attr->strideH = stride[2]; attr->strideH = stride[2];
attr->strideW = stride[3]; attr->strideW = stride[3];
@ -175,7 +175,7 @@ void Conv2D::PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT
int channel_mutiplier = 1; int channel_mutiplier = 1;
if (prim.GetAttr("channel_mutiplier") != nullptr) { if (prim.GetAttr("channel_mutiplier") != nullptr) {
channel_mutiplier = GetValue<int>(prim.GetAttr("channel_multiplier")); channel_mutiplier = CastToInt(prim.GetAttr("channel_multiplier"), false).front();
} }
attr->channelMultiplier = channel_mutiplier; attr->channelMultiplier = channel_mutiplier;
@ -212,25 +212,25 @@ void Conv2D::PopulaterConv2DSingleGroup(const Primitive &prim, schema::Primitive
} else { } else {
attr->format = schema::Format::Format_NUM_OF_FORMAT; attr->format = schema::Format::Format_NUM_OF_FORMAT;
} }
auto pad_list = GetValue<std::vector<int>>(prim.GetAttr("pad_list")); auto pad_list = CastToInt(prim.GetAttr("pad_list"), true);
attr->padUp = pad_list[0]; attr->padUp = pad_list[0];
attr->padDown = pad_list[1]; attr->padDown = pad_list[1];
attr->padLeft = pad_list[2]; attr->padLeft = pad_list[2];
attr->padRight = pad_list[3]; attr->padRight = pad_list[3];
auto dilation = GetValue<std::vector<int>>(prim.GetAttr("dilation")); auto dilation = CastToInt(prim.GetAttr("dilation"), true);
attr->dilateH = dilation[0]; attr->dilateH = dilation[0];
attr->dilateW = dilation[1]; attr->dilateW = dilation[1];
auto kernel_size = GetValue<std::vector<int>>(prim.GetAttr("kernel_size")); auto kernel_size = CastToInt(prim.GetAttr("kernel_size"), true);
attr->kernelH = kernel_size[0]; attr->kernelH = kernel_size[0];
attr->kernelW = kernel_size[1]; attr->kernelW = kernel_size[1];
auto stride = GetValue<std::vector<int>>(prim.GetAttr("stride")); auto stride = CastToInt(prim.GetAttr("stride"), true);
attr->strideH = stride[2]; attr->strideH = stride[2];
attr->strideW = stride[3]; attr->strideW = stride[3];
attr->channelOut = GetValue<int>(prim.GetAttr("out_channel")); attr->channelOut = CastToInt(prim.GetAttr("out_channel"), false).front();
auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode")); auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode"));
if (pad_mode == "valid") { if (pad_mode == "valid") {
@ -270,7 +270,7 @@ int Conv2D::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inp
MS_LOG(ERROR) << "conv2d op has no group attr,please check pb model"; MS_LOG(ERROR) << "conv2d op has no group attr,please check pb model";
return RET_NULL_PTR; return RET_NULL_PTR;
} }
int group = GetValue<int>(groupAttr); int group = CastToInt(groupAttr, false).front();
if (group > 1) { if (group > 1) {
PopulaterConv2DMultiGroup(prim, this->primitive_, group, inputs); PopulaterConv2DMultiGroup(prim, this->primitive_, group, inputs);
} else { } else {

@ -94,7 +94,7 @@ int Conv2DGradFilter::UnPackAttr(const Primitive &prim, const std::vector<AnfNod
MS_LOG(ERROR) << "new primitiveT value failed"; MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR; return RET_ERROR;
} }
attr->group = GetValue<int>(prim.GetAttr("group")); attr->group = CastToInt(prim.GetAttr("group"), false).front();
auto format = GetValue<std::string>(prim.GetAttr("data_format")); auto format = GetValue<std::string>(prim.GetAttr("data_format"));
if (format == "NCHW") { if (format == "NCHW") {
attr->format = schema::Format_NCHW; attr->format = schema::Format_NCHW;
@ -103,25 +103,25 @@ int Conv2DGradFilter::UnPackAttr(const Primitive &prim, const std::vector<AnfNod
} else { } else {
attr->format = schema::Format_NUM_OF_FORMAT; attr->format = schema::Format_NUM_OF_FORMAT;
} }
auto pad_list = GetValue<std::vector<int>>(prim.GetAttr("pad_list")); auto pad_list = CastToInt(prim.GetAttr("pad_list"), true);
attr->padUp = pad_list[0]; attr->padUp = pad_list[0];
attr->padDown = pad_list[1]; attr->padDown = pad_list[1];
attr->padLeft = pad_list[2]; attr->padLeft = pad_list[2];
attr->padRight = pad_list[3]; attr->padRight = pad_list[3];
auto dilation = GetValue<std::vector<int>>(prim.GetAttr("dilation")); auto dilation = CastToInt(prim.GetAttr("dilation"), true);
attr->dilateH = dilation[0]; attr->dilateH = dilation[0];
attr->dilateW = dilation[1]; attr->dilateW = dilation[1];
auto kernel_size = GetValue<std::vector<int>>(prim.GetAttr("kernel_size")); auto kernel_size = CastToInt(prim.GetAttr("kernel_size"), true);
attr->kernelH = kernel_size[0]; attr->kernelH = kernel_size[0];
attr->kernelW = kernel_size[1]; attr->kernelW = kernel_size[1];
auto stride = GetValue<std::vector<int>>(prim.GetAttr("stride")); auto stride = CastToInt(prim.GetAttr("stride"), true);
attr->strideH = stride[0]; attr->strideH = stride[0];
attr->strideW = stride[1]; attr->strideW = stride[1];
attr->channelOut = GetValue<int>(prim.GetAttr("out_channel")); attr->channelOut = CastToInt(prim.GetAttr("out_channel"), false).front();
auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode")); auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode"));
if (pad_mode == "valid") { if (pad_mode == "valid") {
attr->padMode = schema::PadMode_VALID; attr->padMode = schema::PadMode_VALID;

@ -92,7 +92,7 @@ int Conv2DGradInput::UnPackAttr(const Primitive &prim, const std::vector<AnfNode
MS_LOG(ERROR) << "new primitiveT value failed"; MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR; return RET_ERROR;
} }
attr->group = GetValue<int>(prim.GetAttr("group")); attr->group = CastToInt(prim.GetAttr("group"), false).front();
if (attr->group > 1) { if (attr->group > 1) {
this->primitive_->value.type = schema::PrimitiveType_GroupConv2DGradInput; this->primitive_->value.type = schema::PrimitiveType_GroupConv2DGradInput;
} }
@ -104,25 +104,25 @@ int Conv2DGradInput::UnPackAttr(const Primitive &prim, const std::vector<AnfNode
} else { } else {
attr->format = schema::Format_NUM_OF_FORMAT; attr->format = schema::Format_NUM_OF_FORMAT;
} }
auto pad_list = GetValue<std::vector<int>>(prim.GetAttr("pad_list")); auto pad_list = CastToInt(prim.GetAttr("pad_list"), true);
attr->padUp = pad_list[0]; attr->padUp = pad_list[0];
attr->padDown = pad_list[1]; attr->padDown = pad_list[1];
attr->padLeft = pad_list[2]; attr->padLeft = pad_list[2];
attr->padRight = pad_list[3]; attr->padRight = pad_list[3];
auto dilation = GetValue<std::vector<int>>(prim.GetAttr("dilation")); auto dilation = CastToInt(prim.GetAttr("dilation"), true);
attr->dilateH = dilation[0]; attr->dilateH = dilation[0];
attr->dilateW = dilation[1]; attr->dilateW = dilation[1];
auto kernel_size = GetValue<std::vector<int>>(prim.GetAttr("kernel_size")); auto kernel_size = CastToInt(prim.GetAttr("kernel_size"), true);
attr->kernelH = kernel_size[0]; attr->kernelH = kernel_size[0];
attr->kernelW = kernel_size[1]; attr->kernelW = kernel_size[1];
auto stride = GetValue<std::vector<int>>(prim.GetAttr("stride")); auto stride = CastToInt(prim.GetAttr("stride"), true);
attr->strideH = stride[0]; attr->strideH = stride[0];
attr->strideW = stride[1]; attr->strideW = stride[1];
attr->channelOut = GetValue<int>(prim.GetAttr("out_channel")); attr->channelOut = CastToInt(prim.GetAttr("out_channel"), false).front();
auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode")); auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode"));
if (pad_mode == "valid") { if (pad_mode == "valid") {

@ -132,21 +132,21 @@ void DeConv2D::PopulaterConv2DMultiGroup(const Primitive &prim, schema::Primitiv
} else { } else {
attr->format = schema::Format::Format_NUM_OF_FORMAT; attr->format = schema::Format::Format_NUM_OF_FORMAT;
} }
auto pad_list = GetValue<std::vector<int>>(prim.GetAttr("pad_list")); auto pad_list = CastToInt(prim.GetAttr("pad_list"), true);
attr->padUp = pad_list[0]; attr->padUp = pad_list[0];
attr->padDown = pad_list[1]; attr->padDown = pad_list[1];
attr->padLeft = pad_list[2]; attr->padLeft = pad_list[2];
attr->padRight = pad_list[3]; attr->padRight = pad_list[3];
auto dilation = GetValue<std::vector<int>>(prim.GetAttr("dilation")); auto dilation = CastToInt(prim.GetAttr("dilation"), true);
attr->dilateH = dilation[0]; attr->dilateH = dilation[0];
attr->dilateW = dilation[1]; attr->dilateW = dilation[1];
auto kernel_size = GetValue<std::vector<int>>(prim.GetAttr("kernel_size")); auto kernel_size = CastToInt(prim.GetAttr("kernel_size"), true);
attr->kernelH = kernel_size[0]; attr->kernelH = kernel_size[0];
attr->kernelW = kernel_size[1]; attr->kernelW = kernel_size[1];
auto stride = GetValue<std::vector<int>>(prim.GetAttr("stride")); auto stride = CastToInt(prim.GetAttr("stride"), true);
attr->strideH = stride[0]; attr->strideH = stride[0];
attr->strideW = stride[1]; attr->strideW = stride[1];
@ -168,7 +168,7 @@ void DeConv2D::PopulaterConv2DMultiGroup(const Primitive &prim, schema::Primitiv
int channel_mutiplier = 1; int channel_mutiplier = 1;
if (prim.GetAttr("channel_mutiplier") != nullptr) { if (prim.GetAttr("channel_mutiplier") != nullptr) {
channel_mutiplier = GetValue<int>(prim.GetAttr("channel_multiplier")); channel_mutiplier = CastToInt(prim.GetAttr("channel_multiplier"), false).front();
} }
attr->channelMultiplier = channel_mutiplier; attr->channelMultiplier = channel_mutiplier;
@ -195,25 +195,25 @@ void DeConv2D::PopulaterDeConv2DSingleGroup(const Primitive &prim, schema::Primi
} else { } else {
attr->format = schema::Format_NUM_OF_FORMAT; attr->format = schema::Format_NUM_OF_FORMAT;
} }
auto pad_list = GetValue<std::vector<int>>(prim.GetAttr("pad_list")); auto pad_list = CastToInt(prim.GetAttr("pad_list"), true);
attr->padUp = pad_list[0]; attr->padUp = pad_list[0];
attr->padDown = pad_list[1]; attr->padDown = pad_list[1];
attr->padLeft = pad_list[2]; attr->padLeft = pad_list[2];
attr->padRight = pad_list[3]; attr->padRight = pad_list[3];
auto dilation = GetValue<std::vector<int>>(prim.GetAttr("dilation")); auto dilation = CastToInt(prim.GetAttr("dilation"), true);
attr->dilateH = dilation[0]; attr->dilateH = dilation[0];
attr->dilateW = dilation[1]; attr->dilateW = dilation[1];
auto kernel_size = GetValue<std::vector<int>>(prim.GetAttr("kernel_size")); auto kernel_size = CastToInt(prim.GetAttr("kernel_size"), true);
attr->kernelH = kernel_size[0]; attr->kernelH = kernel_size[0];
attr->kernelW = kernel_size[1]; attr->kernelW = kernel_size[1];
auto stride = GetValue<std::vector<int>>(prim.GetAttr("stride")); auto stride = CastToInt(prim.GetAttr("stride"), true);
attr->strideH = stride[0]; attr->strideH = stride[0];
attr->strideW = stride[1]; attr->strideW = stride[1];
attr->channelOut = GetValue<int>(prim.GetAttr("out_channel")); attr->channelOut = CastToInt(prim.GetAttr("out_channel"), false).front();
auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode")); auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode"));
if (pad_mode == "valid" || pad_mode == "VALID") { if (pad_mode == "valid" || pad_mode == "VALID") {
@ -248,7 +248,7 @@ int DeConv2D::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &i
MS_LOG(ERROR) << "primitive_ type is error:" << this->primitive_->value.type; MS_LOG(ERROR) << "primitive_ type is error:" << this->primitive_->value.type;
return RET_ERROR; return RET_ERROR;
} }
int group = GetValue<int>(prim.GetAttr("group")); int group = CastToInt(prim.GetAttr("group"), false).front();
if (group == 1) { if (group == 1) {
PopulaterDeConv2DSingleGroup(prim, this->primitive_, group); PopulaterDeConv2DSingleGroup(prim, this->primitive_, group);
} else if (group > 1) { } else if (group > 1) {

@ -86,27 +86,27 @@ int DepthwiseConv2D::UnPackAttr(const Primitive &prim, const std::vector<AnfNode
} else { } else {
attr->format = schema::Format::Format_NUM_OF_FORMAT; attr->format = schema::Format::Format_NUM_OF_FORMAT;
} }
auto pad_list = GetValue<std::vector<int>>(prim.GetAttr("pads")); auto pad_list = CastToInt(prim.GetAttr("pads"), true);
attr->padUp = pad_list[0]; attr->padUp = pad_list[0];
attr->padDown = pad_list[1]; attr->padDown = pad_list[1];
attr->padLeft = pad_list[2]; attr->padLeft = pad_list[2];
attr->padRight = pad_list[3]; attr->padRight = pad_list[3];
auto dilation = GetValue<std::vector<int>>(prim.GetAttr("dilation")); auto dilation = CastToInt(prim.GetAttr("dilation"), true);
attr->dilateH = dilation[0]; attr->dilateH = dilation[0];
attr->dilateW = dilation[1]; attr->dilateW = dilation[1];
if (utils::isa<ValueSequeue>(prim.GetAttr("kernel_size"))) { if (utils::isa<ValueSequeue>(prim.GetAttr("kernel_size"))) {
auto kernel_size = GetValue<std::vector<int>>(prim.GetAttr("kernel_size")); auto kernel_size = CastToInt(prim.GetAttr("kernel_size"), true);
attr->kernelH = kernel_size[0]; attr->kernelH = kernel_size[0];
attr->kernelW = kernel_size[1]; attr->kernelW = kernel_size[1];
} else { } else {
auto kernel_size = GetValue<int>(prim.GetAttr("kernel_size")); auto kernel_size = CastToInt(prim.GetAttr("kernel_size"), false).front();
attr->kernelH = kernel_size; attr->kernelH = kernel_size;
attr->kernelW = kernel_size; attr->kernelW = kernel_size;
} }
auto stride = GetValue<std::vector<int>>(prim.GetAttr("stride")); auto stride = CastToInt(prim.GetAttr("stride"), true);
attr->strideH = stride[2]; attr->strideH = stride[2];
attr->strideW = stride[3]; attr->strideW = stride[3];
@ -124,7 +124,7 @@ int DepthwiseConv2D::UnPackAttr(const Primitive &prim, const std::vector<AnfNode
} else { } else {
attr->activationType = schema::ActivationType_NO_ACTIVATION; attr->activationType = schema::ActivationType_NO_ACTIVATION;
} }
auto channel_multiplier = GetValue<int>(prim.GetAttr("channel_multiplier")); auto channel_multiplier = CastToInt(prim.GetAttr("channel_multiplier"), false).front();
attr->channelMultiplier = channel_multiplier; attr->channelMultiplier = channel_multiplier;
MS_ASSERT(inputs.size() == kAnfPopulaterTwo); MS_ASSERT(inputs.size() == kAnfPopulaterTwo);

@ -53,7 +53,7 @@ int ExpandDims::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr>
// use axis instead of dim // use axis instead of dim
if (inputs[1]->isa<ValueNode>()) { if (inputs[1]->isa<ValueNode>()) {
auto axis_tensor = inputs[1]->cast<ValueNodePtr>(); auto axis_tensor = inputs[1]->cast<ValueNodePtr>();
int axis = GetValue<int>(axis_tensor->value()); int axis = CastToInt(axis_tensor->value(), false).front();
attr->dim = axis; attr->dim = axis;
} else { } else {
MS_LOG(ERROR) << "input axis is not value node."; MS_LOG(ERROR) << "input axis is not value node.";

@ -59,7 +59,7 @@ int Gather::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inp
} }
if (inputs[2]->isa<ValueNode>()) { if (inputs[2]->isa<ValueNode>()) {
ValueNodePtr axis_tensor = inputs[2]->cast<ValueNodePtr>(); ValueNodePtr axis_tensor = inputs[2]->cast<ValueNodePtr>();
int axis = GetValue<int>(axis_tensor->value()); int axis = CastToInt(axis_tensor->value(), false).front();
gather_attr->axis = axis; gather_attr->axis = axis;
} else { } else {
MS_LOG(ERROR) << "input axis is not value node."; MS_LOG(ERROR) << "input axis is not value node.";

@ -48,7 +48,7 @@ int OneHot::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inp
} }
attr->axis = -1; attr->axis = -1;
if (prim.GetAttr("axis") != nullptr) { if (prim.GetAttr("axis") != nullptr) {
attr->axis = GetValue<int>(prim.GetAttr("axis")); attr->axis = CastToInt(prim.GetAttr("axis"), false).front();
} }
this->primitive_->value.value = attr; this->primitive_->value.value = attr;
if (this->primitive_->value.value == nullptr) { if (this->primitive_->value.value == nullptr) {

@ -110,11 +110,11 @@ int Pooling::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &in
attr->padMode = schema::PadMode_NOTSET; attr->padMode = schema::PadMode_NOTSET;
} }
auto kernel_size = GetValue<std::vector<int>>(prim.GetAttr("ksize")); auto kernel_size = CastToInt(prim.GetAttr("ksize"), true);
attr->windowH = kernel_size[2]; attr->windowH = kernel_size[2];
attr->windowW = kernel_size[3]; attr->windowW = kernel_size[3];
auto stride = GetValue<std::vector<int>>(prim.GetAttr("strides")); auto stride = CastToInt(prim.GetAttr("strides"), true);
attr->strideH = stride[2]; attr->strideH = stride[2];
attr->strideW = stride[3]; attr->strideW = stride[3];
this->primitive_->value.value = attr; this->primitive_->value.value = attr;

@ -99,11 +99,11 @@ int PoolingGrad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr>
attr->padMode = schema::PadMode_NOTSET; attr->padMode = schema::PadMode_NOTSET;
} }
auto kernel_size = GetValue<std::vector<int>>(prim.GetAttr("ksize")); auto kernel_size = CastToInt(prim.GetAttr("ksize"), true);
attr->windowH = kernel_size[2]; attr->windowH = kernel_size[2];
attr->windowW = kernel_size[3]; attr->windowW = kernel_size[3];
auto stride = GetValue<std::vector<int>>(prim.GetAttr("strides")); auto stride = CastToInt(prim.GetAttr("strides"), true);
attr->strideH = stride[2]; attr->strideH = stride[2];
attr->strideW = stride[3]; attr->strideW = stride[3];
this->primitive_->value.value = attr; this->primitive_->value.value = attr;

@ -180,6 +180,35 @@
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
#ifdef PRIMITIVE_WRITEABLE #ifdef PRIMITIVE_WRITEABLE
std::vector<int> CastToInt(const ValuePtr value, bool is_vector) {
if (value == nullptr) {
MS_LOG(WARNING) << "valueptr is nullptr.";
return {};
}
std::vector<int> cur_value;
if (is_vector) {
if (!utils::isa<ValueSequeuePtr>(value)) {
MS_LOG(WARNING) << "valueptr is not a sequence, value may be a scalar.";
return {};
}
if (value->cast<ValueSequeuePtr>()->value().front()->type()->type_name() == "Int64Imm") {
auto origin_value = GetValue<std::vector<int64_t>>(value);
for (size_t index = 0; index < origin_value.size(); ++index) {
cur_value.push_back(static_cast<int>(origin_value[index]));
}
} else {
cur_value = GetValue<std::vector<int>>(value);
}
} else {
if (value->type_name() == "Int64Imm") {
cur_value.push_back(static_cast<int>(GetValue<int64_t>(value)));
} else {
cur_value.push_back(GetValue<int>(value));
}
}
return cur_value;
}
void PrimitiveC::CalFloatScopeByMeanAndStddev(const double &mean, const double &stdDev, float *mMin, float *mMax) { void PrimitiveC::CalFloatScopeByMeanAndStddev(const double &mean, const double &stdDev, float *mMin, float *mMax) {
const float qmin = 0; const float qmin = 0;
const float qmax = 255; const float qmax = 255;

@ -52,6 +52,8 @@ static std::map<std::string, schema::ActivationType> kActivationTypeMap{{"ReLU",
{"Sigmoid", schema::ActivationType_SIGMOID}, {"Sigmoid", schema::ActivationType_SIGMOID},
{"HSwish", schema::ActivationType_HSWISH}, {"HSwish", schema::ActivationType_HSWISH},
{"HSigmoid", schema::ActivationType_HSIGMOID}}; {"HSigmoid", schema::ActivationType_HSIGMOID}};
std::vector<int> CastToInt(const ValuePtr value, bool is_vector);
class PrimitiveC : public mindspore::Primitive { class PrimitiveC : public mindspore::Primitive {
public: public:
// Argument primitive is deliverd into PrimitiveC and will be deleted in ~PrimitiveC(). // Argument primitive is deliverd into PrimitiveC and will be deleted in ~PrimitiveC().

@ -87,7 +87,7 @@ int Reduce::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inp
attr->axes.emplace_back(elem->value()); attr->axes.emplace_back(elem->value());
} }
} else { } else {
int axes_item = GetValue<int>(value); int axes_item = CastToInt(value, false).front();
attr->axes.push_back(axes_item); attr->axes.push_back(axes_item);
} }
} }

@ -63,7 +63,7 @@ int Reshape::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &in
attr->shape.emplace_back(static_cast<int>(elem->value())); attr->shape.emplace_back(static_cast<int>(elem->value()));
} }
} else { } else {
int dim = GetValue<int>(val); int dim = CastToInt(val, false).front();
attr->shape = {dim}; attr->shape = {dim};
} }
} }

@ -67,7 +67,7 @@ int Resize::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inp
MS_LOG(ERROR) << "wrong resize type"; MS_LOG(ERROR) << "wrong resize type";
return RET_ERROR; return RET_ERROR;
} }
std::vector<int> targetSize = GetValue<std::vector<int>>(prim.GetAttr("size")); std::vector<int> targetSize = CastToInt(prim.GetAttr("size"), true);
attr->newHeight = targetSize[0]; attr->newHeight = targetSize[0];
attr->newWidth = targetSize[1]; attr->newWidth = targetSize[1];
attr->alignCorners = GetValue<bool>(prim.GetAttr("align_corners")); attr->alignCorners = GetValue<bool>(prim.GetAttr("align_corners"));

@ -43,7 +43,7 @@ int SoftMax::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &in
MS_LOG(ERROR) << "new primitiveT value failed"; MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR; return RET_ERROR;
} }
auto prim_axis = GetValue<int>(prim.GetAttr("axis")); auto prim_axis = CastToInt(prim.GetAttr("axis"), false).front();
attr->axis = prim_axis; attr->axis = prim_axis;
this->primitive_->value.value = attr; this->primitive_->value.value = attr;
if (this->primitive_->value.value == nullptr) { if (this->primitive_->value.value == nullptr) {

@ -50,7 +50,7 @@ int Squeeze::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &in
MS_LOG(INFO) << "Squeeze's attr xis is set to default"; MS_LOG(INFO) << "Squeeze's attr xis is set to default";
attr->axis = {0}; attr->axis = {0};
} else { } else {
attr->axis = GetValue<std::vector<int>>(prim.GetAttr("axis")); attr->axis = CastToInt(prim.GetAttr("axis"), true);
} }
this->primitive_->value.value = attr; this->primitive_->value.value = attr;
} }

@ -73,11 +73,11 @@ int StridedSlice::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr
MS_LOG(ERROR) << "new StridedSlice failed"; MS_LOG(ERROR) << "new StridedSlice failed";
return RET_ERROR; return RET_ERROR;
} }
attr->beginMask = GetValue<int>(prim.GetAttr("begin_mask")); attr->beginMask = CastToInt(prim.GetAttr("begin_mask"), false).front();
attr->endMask = GetValue<int>(prim.GetAttr("end_mask")); attr->endMask = CastToInt(prim.GetAttr("end_mask"), false).front();
attr->ellipsisMask = GetValue<int>(prim.GetAttr("ellipsis_mask")); attr->ellipsisMask = CastToInt(prim.GetAttr("ellipsis_mask"), false).front();
attr->newAxisMask = GetValue<int>(prim.GetAttr("new_axis_mask")); attr->newAxisMask = CastToInt(prim.GetAttr("new_axis_mask"), false).front();
attr->shrinkAxisMask = GetValue<int>(prim.GetAttr("shrink_axis_mask")); attr->shrinkAxisMask = CastToInt(prim.GetAttr("shrink_axis_mask"), false).front();
auto inputNodeFirst = inputs[kAnfPopulaterOne]; auto inputNodeFirst = inputs[kAnfPopulaterOne];
std::vector<int> beginVec; std::vector<int> beginVec;
GetAttrDataFromInput(inputNodeFirst, &beginVec); GetAttrDataFromInput(inputNodeFirst, &beginVec);

@ -56,7 +56,7 @@ int Tile::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &input
MS_LOG(INFO) << "Tile's attr dims is set to default"; MS_LOG(INFO) << "Tile's attr dims is set to default";
attr->dims = {1}; attr->dims = {1};
} else { } else {
attr->dims = GetValue<std::vector<int>>(prim.GetAttr("dims")); attr->dims = CastToInt(prim.GetAttr("dims"), true);
} }
if (inputs.size() == kAnfPopulaterTwo) { if (inputs.size() == kAnfPopulaterTwo) {
auto inputNode = inputs[kAnfPopulaterOne]; auto inputNode = inputs[kAnfPopulaterOne];
@ -75,7 +75,7 @@ int Tile::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &input
attr->multiples.emplace_back(elem->value()); attr->multiples.emplace_back(elem->value());
} }
} else { } else {
int multiple = GetValue<int>(value); int multiple = CastToInt(value, false).front();
attr->multiples = {multiple}; attr->multiples = {multiple};
} }
} }

@ -48,7 +48,7 @@ int UnsortedSegmentSum::UnPackAttr(const Primitive &prim, const std::vector<AnfN
std::unique_ptr<schema::UnsortedSegmentSumT> attr = std::make_unique<schema::UnsortedSegmentSumT>(); std::unique_ptr<schema::UnsortedSegmentSumT> attr = std::make_unique<schema::UnsortedSegmentSumT>();
if (inputs[2]->isa<ValueNode>()) { if (inputs[2]->isa<ValueNode>()) {
ValuePtr value = inputs[2]->cast<ValueNodePtr>()->value(); ValuePtr value = inputs[2]->cast<ValueNodePtr>()->value();
attr->numSegments = GetValue<int>(value); attr->numSegments = CastToInt(value, false).front();
this->primitive_->value.value = attr.release(); this->primitive_->value.value = attr.release();
} }
} }

@ -314,7 +314,9 @@ int AnfExporter::ConvertInputCNode(const std::shared_ptr<AnfNode> input_anode, s
return RET_ERROR; return RET_ERROR;
} }
auto input_index_key = auto input_index_key =
get_item_input_cnode->fullname_with_scope() + "_o:" + std::to_string(GetValue<int>(value_node->value())); get_item_input_cnode->fullname_with_scope() + "_o:" +
std::to_string(value_node->value()->type_name() == "Int64Imm" ? GetValue<int64_t>(value_node->value())
: GetValue<int>(value_node->value()));
auto iter = node_id_map_.find(input_index_key); auto iter = node_id_map_.find(input_index_key);
if (iter == node_id_map_.end()) { if (iter == node_id_map_.end()) {
#ifdef SUPPORT_TRAIN #ifdef SUPPORT_TRAIN

Loading…
Cancel
Save