!8909 [lite] fix tflite deconv parser and adjust mindir int64

From: @xu_anyue
Reviewed-by: 
Signed-off-by:
pull/8909/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit ea5bacec49

@ -51,7 +51,7 @@ int BiasAdd::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &in
MS_LOG(INFO) << "BiasAdd's attr axis is set to default";
attr->axis = {1};
} else {
attr->axis = CastToInt(prim.GetAttr("axis"), true);
attr->axis = CastToInt(prim.GetAttr("axis"));
}
this->primitive_->value.value = attr;
}

@ -49,7 +49,7 @@ int BiasGrad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &i
MS_LOG(WARNING) << "get axis failed";
attr->axis = {0};
} else {
attr->axis = CastToInt(prim.GetAttr("axis"), true);
attr->axis = CastToInt(prim.GetAttr("axis"));
}
this->primitive_->value.value = attr;
}

@ -51,7 +51,7 @@ int Concat::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inp
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
auto prim_axis = CastToInt(prim.GetAttr("axis"), false).front();
auto prim_axis = CastToInt(prim.GetAttr("axis")).front();
attr->axis = prim_axis;
this->primitive_->value.value = attr;
if (this->primitive_->value.value == nullptr) {

@ -143,21 +143,21 @@ void Conv2D::PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT
} else {
attr->format = schema::Format::Format_NUM_OF_FORMAT;
}
auto pad_list = CastToInt(prim.GetAttr("pad_list"), true);
auto pad_list = CastToInt(prim.GetAttr("pad_list"));
attr->padUp = pad_list[0];
attr->padDown = pad_list[1];
attr->padLeft = pad_list[2];
attr->padRight = pad_list[3];
auto dilation = CastToInt(prim.GetAttr("dilation"), true);
auto dilation = CastToInt(prim.GetAttr("dilation"));
attr->dilateH = dilation[0];
attr->dilateW = dilation[1];
auto kernel_size = CastToInt(prim.GetAttr("kernel_size"), true);
auto kernel_size = CastToInt(prim.GetAttr("kernel_size"));
attr->kernelH = kernel_size[0];
attr->kernelW = kernel_size[1];
auto stride = CastToInt(prim.GetAttr("stride"), true);
auto stride = CastToInt(prim.GetAttr("stride"));
attr->strideH = stride[2];
attr->strideW = stride[3];
@ -179,7 +179,7 @@ void Conv2D::PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT
int channel_mutiplier = 1;
if (prim.GetAttr("channel_mutiplier") != nullptr) {
channel_mutiplier = CastToInt(prim.GetAttr("channel_multiplier"), false).front();
channel_mutiplier = CastToInt(prim.GetAttr("channel_multiplier")).front();
}
attr->channelMultiplier = channel_mutiplier;
@ -220,25 +220,25 @@ void Conv2D::PopulaterConv2DSingleGroup(const Primitive &prim, schema::Primitive
} else {
attr->format = schema::Format::Format_NUM_OF_FORMAT;
}
auto pad_list = CastToInt(prim.GetAttr("pad_list"), true);
auto pad_list = CastToInt(prim.GetAttr("pad_list"));
attr->padUp = pad_list[0];
attr->padDown = pad_list[1];
attr->padLeft = pad_list[2];
attr->padRight = pad_list[3];
auto dilation = CastToInt(prim.GetAttr("dilation"), true);
auto dilation = CastToInt(prim.GetAttr("dilation"));
attr->dilateH = dilation[2];
attr->dilateW = dilation[3];
auto kernel_size = CastToInt(prim.GetAttr("kernel_size"), true);
auto kernel_size = CastToInt(prim.GetAttr("kernel_size"));
attr->kernelH = kernel_size[0];
attr->kernelW = kernel_size[1];
auto stride = CastToInt(prim.GetAttr("stride"), true);
auto stride = CastToInt(prim.GetAttr("stride"));
attr->strideH = stride[2];
attr->strideW = stride[3];
attr->channelOut = CastToInt(prim.GetAttr("out_channel"), false).front();
attr->channelOut = CastToInt(prim.GetAttr("out_channel")).front();
auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode"));
if (pad_mode == "valid") {
@ -278,7 +278,7 @@ int Conv2D::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inp
MS_LOG(ERROR) << "conv2d op has no group attr,please check pb model";
return RET_NULL_PTR;
}
int group = CastToInt(groupAttr, false).front();
int group = CastToInt(groupAttr).front();
if (group > 1) {
PopulaterConv2DMultiGroup(prim, this->primitive_, group, inputs);
} else {

@ -94,7 +94,7 @@ int Conv2DGradFilter::UnPackAttr(const Primitive &prim, const std::vector<AnfNod
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
attr->group = CastToInt(prim.GetAttr("group"), false).front();
attr->group = CastToInt(prim.GetAttr("group")).front();
auto format = GetValue<std::string>(prim.GetAttr("data_format"));
if (format == "NCHW") {
attr->format = schema::Format_NCHW;
@ -103,25 +103,25 @@ int Conv2DGradFilter::UnPackAttr(const Primitive &prim, const std::vector<AnfNod
} else {
attr->format = schema::Format_NUM_OF_FORMAT;
}
auto pad_list = CastToInt(prim.GetAttr("pad_list"), true);
auto pad_list = CastToInt(prim.GetAttr("pad_list"));
attr->padUp = pad_list[0];
attr->padDown = pad_list[1];
attr->padLeft = pad_list[2];
attr->padRight = pad_list[3];
auto dilation = CastToInt(prim.GetAttr("dilation"), true);
auto dilation = CastToInt(prim.GetAttr("dilation"));
attr->dilateH = dilation[0];
attr->dilateW = dilation[1];
auto kernel_size = CastToInt(prim.GetAttr("kernel_size"), true);
auto kernel_size = CastToInt(prim.GetAttr("kernel_size"));
attr->kernelH = kernel_size[0];
attr->kernelW = kernel_size[1];
auto stride = CastToInt(prim.GetAttr("stride"), true);
auto stride = CastToInt(prim.GetAttr("stride"));
attr->strideH = stride[0];
attr->strideW = stride[1];
attr->channelOut = CastToInt(prim.GetAttr("out_channel"), false).front();
attr->channelOut = CastToInt(prim.GetAttr("out_channel")).front();
auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode"));
if (pad_mode == "valid") {
attr->padMode = schema::PadMode_VALID;
@ -154,7 +154,7 @@ int Conv2DGradFilter::UnPackAttr(const Primitive &prim, const std::vector<AnfNod
for (size_t i = 0; i < valTuplPtr->size(); i++) {
auto elem = (*valTuplPtr)[i];
MS_ASSERT(elem != nullptr);
attr->filter_shape[nchw2nhwc[i]] = CastToInt(elem, false).front();
attr->filter_shape[nchw2nhwc[i]] = CastToInt(elem).front();
}
}
}

@ -92,7 +92,7 @@ int Conv2DGradInput::UnPackAttr(const Primitive &prim, const std::vector<AnfNode
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
attr->group = CastToInt(prim.GetAttr("group"), false).front();
attr->group = CastToInt(prim.GetAttr("group")).front();
if (attr->group > 1) {
this->primitive_->value.type = schema::PrimitiveType_GroupConv2DGradInput;
}
@ -104,25 +104,25 @@ int Conv2DGradInput::UnPackAttr(const Primitive &prim, const std::vector<AnfNode
} else {
attr->format = schema::Format_NUM_OF_FORMAT;
}
auto pad_list = CastToInt(prim.GetAttr("pad_list"), true);
auto pad_list = CastToInt(prim.GetAttr("pad_list"));
attr->padUp = pad_list[0];
attr->padDown = pad_list[1];
attr->padLeft = pad_list[2];
attr->padRight = pad_list[3];
auto dilation = CastToInt(prim.GetAttr("dilation"), true);
auto dilation = CastToInt(prim.GetAttr("dilation"));
attr->dilateH = dilation[0];
attr->dilateW = dilation[1];
auto kernel_size = CastToInt(prim.GetAttr("kernel_size"), true);
auto kernel_size = CastToInt(prim.GetAttr("kernel_size"));
attr->kernelH = kernel_size[0];
attr->kernelW = kernel_size[1];
auto stride = CastToInt(prim.GetAttr("stride"), true);
auto stride = CastToInt(prim.GetAttr("stride"));
attr->strideH = stride[0];
attr->strideW = stride[1];
attr->channelOut = CastToInt(prim.GetAttr("out_channel"), false).front();
attr->channelOut = CastToInt(prim.GetAttr("out_channel")).front();
auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode"));
if (pad_mode == "valid") {
@ -156,7 +156,7 @@ int Conv2DGradInput::UnPackAttr(const Primitive &prim, const std::vector<AnfNode
for (size_t i = 0; i < valTuplPtr->size(); i++) {
auto elem = (*valTuplPtr)[i];
MS_ASSERT(elem != nullptr);
attr->input_shape[nchw2nhwc[i]] = CastToInt(elem, false).front();
attr->input_shape[nchw2nhwc[i]] = CastToInt(elem).front();
}
}
}

@ -136,21 +136,21 @@ void DeConv2D::PopulaterConv2DMultiGroup(const Primitive &prim, schema::Primitiv
} else {
attr->format = schema::Format::Format_NUM_OF_FORMAT;
}
auto pad_list = CastToInt(prim.GetAttr("pad_list"), true);
auto pad_list = CastToInt(prim.GetAttr("pad_list"));
attr->padUp = pad_list[0];
attr->padDown = pad_list[1];
attr->padLeft = pad_list[2];
attr->padRight = pad_list[3];
auto dilation = CastToInt(prim.GetAttr("dilation"), true);
auto dilation = CastToInt(prim.GetAttr("dilation"));
attr->dilateH = dilation[0];
attr->dilateW = dilation[1];
auto kernel_size = CastToInt(prim.GetAttr("kernel_size"), true);
auto kernel_size = CastToInt(prim.GetAttr("kernel_size"));
attr->kernelH = kernel_size[0];
attr->kernelW = kernel_size[1];
auto stride = CastToInt(prim.GetAttr("stride"), true);
auto stride = CastToInt(prim.GetAttr("stride"));
attr->strideH = stride[0];
attr->strideW = stride[1];
@ -172,7 +172,7 @@ void DeConv2D::PopulaterConv2DMultiGroup(const Primitive &prim, schema::Primitiv
int channel_mutiplier = 1;
if (prim.GetAttr("channel_mutiplier") != nullptr) {
channel_mutiplier = CastToInt(prim.GetAttr("channel_multiplier"), false).front();
channel_mutiplier = CastToInt(prim.GetAttr("channel_multiplier")).front();
}
attr->channelMultiplier = channel_mutiplier;
@ -203,25 +203,25 @@ void DeConv2D::PopulaterDeConv2DSingleGroup(const Primitive &prim, schema::Primi
} else {
attr->format = schema::Format_NUM_OF_FORMAT;
}
auto pad_list = CastToInt(prim.GetAttr("pad_list"), true);
auto pad_list = CastToInt(prim.GetAttr("pad_list"));
attr->padUp = pad_list[0];
attr->padDown = pad_list[1];
attr->padLeft = pad_list[2];
attr->padRight = pad_list[3];
auto dilation = CastToInt(prim.GetAttr("dilation"), true);
auto dilation = CastToInt(prim.GetAttr("dilation"));
attr->dilateH = dilation[0];
attr->dilateW = dilation[1];
auto kernel_size = CastToInt(prim.GetAttr("kernel_size"), true);
auto kernel_size = CastToInt(prim.GetAttr("kernel_size"));
attr->kernelH = kernel_size[0];
attr->kernelW = kernel_size[1];
auto stride = CastToInt(prim.GetAttr("stride"), true);
auto stride = CastToInt(prim.GetAttr("stride"));
attr->strideH = stride[0];
attr->strideW = stride[1];
attr->channelOut = CastToInt(prim.GetAttr("out_channel"), false).front();
attr->channelOut = CastToInt(prim.GetAttr("out_channel")).front();
auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode"));
if (pad_mode == "valid" || pad_mode == "VALID") {
@ -256,7 +256,7 @@ int DeConv2D::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &i
MS_LOG(ERROR) << "primitive_ type is error:" << this->primitive_->value.type;
return RET_ERROR;
}
int group = CastToInt(prim.GetAttr("group"), false).front();
int group = CastToInt(prim.GetAttr("group")).front();
if (group == 1) {
PopulaterDeConv2DSingleGroup(prim, this->primitive_, group);
} else if (group > 1) {

@ -86,27 +86,27 @@ int DepthwiseConv2D::UnPackAttr(const Primitive &prim, const std::vector<AnfNode
} else {
attr->format = schema::Format::Format_NUM_OF_FORMAT;
}
auto pad_list = CastToInt(prim.GetAttr("pads"), true);
auto pad_list = CastToInt(prim.GetAttr("pads"));
attr->padUp = pad_list[0];
attr->padDown = pad_list[1];
attr->padLeft = pad_list[2];
attr->padRight = pad_list[3];
auto dilation = CastToInt(prim.GetAttr("dilation"), true);
auto dilation = CastToInt(prim.GetAttr("dilation"));
attr->dilateH = dilation[0];
attr->dilateW = dilation[1];
if (utils::isa<ValueSequeue>(prim.GetAttr("kernel_size"))) {
auto kernel_size = CastToInt(prim.GetAttr("kernel_size"), true);
auto kernel_size = CastToInt(prim.GetAttr("kernel_size"));
attr->kernelH = kernel_size[0];
attr->kernelW = kernel_size[1];
} else {
auto kernel_size = CastToInt(prim.GetAttr("kernel_size"), false).front();
auto kernel_size = CastToInt(prim.GetAttr("kernel_size")).front();
attr->kernelH = kernel_size;
attr->kernelW = kernel_size;
}
auto stride = CastToInt(prim.GetAttr("stride"), true);
auto stride = CastToInt(prim.GetAttr("stride"));
attr->strideH = stride[2];
attr->strideW = stride[3];
@ -124,7 +124,7 @@ int DepthwiseConv2D::UnPackAttr(const Primitive &prim, const std::vector<AnfNode
} else {
attr->activationType = schema::ActivationType_NO_ACTIVATION;
}
auto channel_multiplier = CastToInt(prim.GetAttr("channel_multiplier"), false).front();
auto channel_multiplier = CastToInt(prim.GetAttr("channel_multiplier")).front();
attr->channelMultiplier = channel_multiplier;
MS_ASSERT(inputs.size() == kAnfPopulaterInputNumTwo);

@ -53,7 +53,7 @@ int ExpandDims::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr>
// use axis instead of dim
if (inputs[1]->isa<ValueNode>()) {
auto axis_tensor = inputs[1]->cast<ValueNodePtr>();
int axis = CastToInt(axis_tensor->value(), false).front();
int axis = CastToInt(axis_tensor->value()).front();
attr->dim = axis;
} else {
MS_LOG(ERROR) << "input axis is not value node.";

@ -59,7 +59,7 @@ int Gather::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inp
}
if (inputs[2]->isa<ValueNode>()) {
ValueNodePtr axis_tensor = inputs[2]->cast<ValueNodePtr>();
int axis = CastToInt(axis_tensor->value(), false).front();
int axis = CastToInt(axis_tensor->value()).front();
gather_attr->axis = axis;
} else {
MS_LOG(ERROR) << "input axis is not value node.";

@ -48,7 +48,7 @@ int OneHot::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inp
}
attr->axis = -1;
if (prim.GetAttr("axis") != nullptr) {
attr->axis = CastToInt(prim.GetAttr("axis"), false).front();
attr->axis = CastToInt(prim.GetAttr("axis")).front();
}
this->primitive_->value.value = attr;
if (this->primitive_->value.value == nullptr) {

@ -110,11 +110,11 @@ int Pooling::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &in
attr->padMode = schema::PadMode_NOTSET;
}
auto kernel_size = CastToInt(prim.GetAttr("ksize"), true);
auto kernel_size = CastToInt(prim.GetAttr("ksize"));
attr->windowH = kernel_size[2];
attr->windowW = kernel_size[3];
auto stride = CastToInt(prim.GetAttr("strides"), true);
auto stride = CastToInt(prim.GetAttr("strides"));
attr->strideH = stride[2];
attr->strideW = stride[3];
this->primitive_->value.value = attr;

@ -99,11 +99,11 @@ int PoolingGrad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr>
attr->padMode = schema::PadMode_NOTSET;
}
auto kernel_size = CastToInt(prim.GetAttr("ksize"), true);
auto kernel_size = CastToInt(prim.GetAttr("ksize"));
attr->windowH = kernel_size[2];
attr->windowW = kernel_size[3];
auto stride = CastToInt(prim.GetAttr("strides"), true);
auto stride = CastToInt(prim.GetAttr("strides"));
attr->strideH = stride[2];
attr->strideW = stride[3];
this->primitive_->value.value = attr;

@ -181,17 +181,13 @@
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
std::vector<int> CastToInt(const ValuePtr value, bool is_vector) {
std::vector<int> CastToInt(const ValuePtr value) {
if (value == nullptr) {
MS_LOG(WARNING) << "valueptr is nullptr.";
return {};
}
std::vector<int> cur_value;
if (is_vector) {
if (!utils::isa<ValueSequeuePtr>(value)) {
MS_LOG(WARNING) << "valueptr is not a sequence, value may be a scalar.";
return {};
}
if (utils::isa<ValueSequeuePtr>(value)) {
if (value->cast<ValueSequeuePtr>()->value().front()->type()->number_type() == kNumberTypeInt64) {
auto origin_value = GetValue<std::vector<int64_t>>(value);
for (size_t index = 0; index < origin_value.size(); ++index) {
@ -337,7 +333,7 @@ void PrimitiveC::GetAttrDataFromInput(const AnfNodePtr &inputNode, std::vector<i
for (size_t i = 0; i < tuple->size(); i++) {
auto elem = tuple->value()[i];
MS_ASSERT(elem != nullptr);
data->emplace_back(CastToInt(elem, false).front());
data->emplace_back(CastToInt(elem).front());
}
}
}

@ -57,7 +57,7 @@ static std::map<std::string, schema::ActivationType> kActivationTypeMap{
{"LeakyRelu", schema::ActivationType_LEAKY_RELU},
{"Tanh", schema::ActivationType_TANH},
{"Logistic", schema::ActivationType_SIGMOID}};
std::vector<int> CastToInt(const ValuePtr value, bool is_vector);
std::vector<int> CastToInt(const ValuePtr value);
class PrimitiveC : public mindspore::Primitive {
public:
// Argument primitive is deliverd into PrimitiveC and will be deleted in ~PrimitiveC().

@ -84,10 +84,10 @@ int Reduce::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inp
for (size_t i = 0; i < valTuplPtr->size(); i++) {
auto elem = (*valTuplPtr)[i];
MS_ASSERT(elem != nullptr);
attr->axes.emplace_back(CastToInt(elem, false).front());
attr->axes.emplace_back(CastToInt(elem).front());
}
} else {
int axes_item = CastToInt(value, false).front();
int axes_item = CastToInt(value).front();
attr->axes.push_back(axes_item);
}
}

@ -60,10 +60,10 @@ int Reshape::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &in
for (size_t i = 0; i < tuple->size(); ++i) {
auto elem = tuple->value()[i];
MS_ASSERT(elem != nullptr);
attr->shape.emplace_back(CastToInt(elem, false).front());
attr->shape.emplace_back(CastToInt(elem).front());
}
} else {
int dim = CastToInt(val, false).front();
int dim = CastToInt(val).front();
attr->shape = {dim};
}
}

@ -67,7 +67,7 @@ int Resize::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inp
MS_LOG(ERROR) << "wrong resize type";
return RET_ERROR;
}
std::vector<int> targetSize = CastToInt(prim.GetAttr("size"), true);
std::vector<int> targetSize = CastToInt(prim.GetAttr("size"));
attr->newHeight = targetSize[0];
attr->newWidth = targetSize[1];
attr->alignCorners = GetValue<bool>(prim.GetAttr("align_corners"));

@ -73,7 +73,7 @@ int Slice::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inpu
for (size_t i = 0; i < valTuplPtr->size(); i++) {
auto elem = (*valTuplPtr)[i];
MS_ASSERT(elem != nullptr);
attr->begin.emplace_back(CastToInt(elem, false).front());
attr->begin.emplace_back(CastToInt(elem).front());
}
}
}
@ -90,7 +90,7 @@ int Slice::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inpu
for (size_t i = 0; i < valTuplPtr->size(); i++) {
auto elem = (*valTuplPtr)[i];
MS_ASSERT(elem != nullptr);
attr->size.emplace_back(CastToInt(elem, false).front());
attr->size.emplace_back(CastToInt(elem).front());
}
}
}

@ -43,7 +43,7 @@ int SoftMax::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &in
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
auto prim_axis = CastToInt(prim.GetAttr("axis"), false).front();
auto prim_axis = CastToInt(prim.GetAttr("axis")).front();
attr->axis = prim_axis;
this->primitive_->value.value = attr;
if (this->primitive_->value.value == nullptr) {

@ -50,7 +50,7 @@ int Squeeze::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &in
MS_LOG(INFO) << "Squeeze's attr xis is set to default";
attr->axis = {0};
} else {
attr->axis = CastToInt(prim.GetAttr("axis"), true);
attr->axis = CastToInt(prim.GetAttr("axis"));
}
this->primitive_->value.value = attr;
}

@ -74,11 +74,11 @@ int StridedSlice::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr
MS_LOG(ERROR) << "new StridedSlice failed";
return RET_ERROR;
}
attr->beginMask = CastToInt(prim.GetAttr("begin_mask"), false).front();
attr->endMask = CastToInt(prim.GetAttr("end_mask"), false).front();
attr->ellipsisMask = CastToInt(prim.GetAttr("ellipsis_mask"), false).front();
attr->newAxisMask = CastToInt(prim.GetAttr("new_axis_mask"), false).front();
attr->shrinkAxisMask = CastToInt(prim.GetAttr("shrink_axis_mask"), false).front();
attr->beginMask = CastToInt(prim.GetAttr("begin_mask")).front();
attr->endMask = CastToInt(prim.GetAttr("end_mask")).front();
attr->ellipsisMask = CastToInt(prim.GetAttr("ellipsis_mask")).front();
attr->newAxisMask = CastToInt(prim.GetAttr("new_axis_mask")).front();
attr->shrinkAxisMask = CastToInt(prim.GetAttr("shrink_axis_mask")).front();
auto inputNodeFirst = inputs[kAnfPopulaterInputNumOne];
std::vector<int> beginVec;
GetAttrDataFromInput(inputNodeFirst, &beginVec);

@ -56,7 +56,7 @@ int Tile::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &input
MS_LOG(INFO) << "Tile's attr dims is set to default";
attr->dims = {1};
} else {
attr->dims = CastToInt(prim.GetAttr("dims"), true);
attr->dims = CastToInt(prim.GetAttr("dims"));
}
if (inputs.size() == kAnfPopulaterInputNumTwo) {
auto inputNode = inputs[kAnfPopulaterInputNumOne];
@ -72,10 +72,10 @@ int Tile::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &input
for (size_t i = 0; i < valTuplPtr->size(); i++) {
auto elem = (*valTuplPtr)[i];
MS_ASSERT(elem != nullptr);
attr->multiples.emplace_back(CastToInt(elem, false).front());
attr->multiples.emplace_back(CastToInt(elem).front());
}
} else {
int multiple = CastToInt(value, false).front();
int multiple = CastToInt(value).front();
attr->multiples = {multiple};
}
}

@ -64,7 +64,7 @@ int Transpose::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &
for (size_t i = 0; i < tuple->size(); i++) {
auto elem = tuple->value()[i];
MS_ASSERT(elem != nullptr);
attr->perm.emplace_back(CastToInt(elem, false).front());
attr->perm.emplace_back(CastToInt(elem).front());
}
}
}

@ -48,7 +48,7 @@ int UnsortedSegmentSum::UnPackAttr(const Primitive &prim, const std::vector<AnfN
std::unique_ptr<schema::UnsortedSegmentSumT> attr = std::make_unique<schema::UnsortedSegmentSumT>();
if (inputs[2]->isa<ValueNode>()) {
ValuePtr value = inputs[2]->cast<ValueNodePtr>()->value();
attr->numSegments = CastToInt(value, false).front();
attr->numSegments = CastToInt(value).front();
this->primitive_->value.value = attr.release();
}
}

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save