stride_slice-5

pull/11365/head
yefeng 4 years ago
parent f679fcf075
commit 152992d3a9

@ -359,7 +359,9 @@ int StridedSlice::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lit
MS_ASSERT(input != nullptr);
auto input_shape = input->shape();
auto inferflag = infer_flag();
if (!infer_flag()) {
return RET_INFER_INVALID;
}
in_shape_.clear();
if (inferflag) {
in_shape_.assign(input_shape.begin(), input_shape.end());

@ -72,76 +72,6 @@ STATUS TFStrideSliceParser::Parse(const tensorflow::NodeDef &tf_op,
return RET_ERROR;
}
attr->shrinkAxisMask = attr_value.i();
// begin
auto begin_node = GetConstInputNode(tf_node_map, tf_op.input(1));
if (begin_node == nullptr) {
MS_LOG(ERROR) << "Find StridedSlice input begin failed";
return RET_ERROR;
}
if (!TensorFlowUtils::FindAttrValue(*begin_node, "value", &attr_value)) {
MS_LOG(ERROR) << "The value attr should be specified";
return RET_ERROR;
}
auto tensor_proto = attr_value.tensor();
if (tensor_proto.int_val_size() > 0) {
for (int i = 0; i < tensor_proto.int_val_size(); ++i) {
attr->begin.push_back(tensor_proto.int_val(i));
}
} else {
auto data_num = tensor_proto.tensor_content().size() / sizeof(int32_t);
auto data = reinterpret_cast<const int32_t *>(tensor_proto.tensor_content().data());
for (size_t i = 0; i < data_num; ++i) {
attr->begin.push_back(data[i]);
}
}
// end
auto end_node = GetConstInputNode(tf_node_map, tf_op.input(2));
if (end_node == nullptr) {
MS_LOG(ERROR) << "Find StridedSlice input end failed";
return RET_ERROR;
}
if (!TensorFlowUtils::FindAttrValue(*end_node, "value", &attr_value)) {
MS_LOG(ERROR) << "The value attr should be specified";
return RET_ERROR;
}
tensor_proto = attr_value.tensor();
if (tensor_proto.int_val_size() > 0) {
for (int i = 0; i < tensor_proto.int_val_size(); ++i) {
attr->end.push_back(tensor_proto.int_val(i));
}
} else {
auto data_num = tensor_proto.tensor_content().size() / sizeof(int32_t);
auto data = reinterpret_cast<const int32_t *>(tensor_proto.tensor_content().data());
for (size_t i = 0; i < data_num; ++i) {
attr->end.push_back(data[i]);
}
}
// strides
auto stride_node = GetConstInputNode(tf_node_map, tf_op.input(3));
if (stride_node == nullptr) {
MS_LOG(ERROR) << "Find StridedSlice input strides failed";
return RET_ERROR;
}
if (!TensorFlowUtils::FindAttrValue(*stride_node, "value", &attr_value)) {
MS_LOG(ERROR) << "The value attr should be specified";
return RET_ERROR;
}
tensor_proto = attr_value.tensor();
if (tensor_proto.int_val_size() > 0) {
for (int i = 0; i < tensor_proto.int_val_size(); ++i) {
attr->stride.push_back(tensor_proto.int_val(i));
}
} else {
auto data_num = tensor_proto.tensor_content().size() / sizeof(int32_t);
auto data = reinterpret_cast<const int32_t *>(tensor_proto.tensor_content().data());
for (size_t i = 0; i < data_num; ++i) {
attr->stride.push_back(data[i]);
}
}
primitive->value.type = schema::PrimitiveType_StridedSlice;
primitive->value.value = attr.release();
*primitiveC = PrimitiveC::Create(primitive.release());
@ -151,7 +81,14 @@ STATUS TFStrideSliceParser::Parse(const tensorflow::NodeDef &tf_op,
}
*output_size = 1;
auto status = AddOpInput(tf_op, 0, inputs);
STATUS status = RET_OK;
for (int i = 0; i < tf_op.input_size(); i++) {
status = AddOpInput(tf_op, i, inputs);
if (status != RET_OK) {
MS_LOG(ERROR) << "Add Op input failed.";
return status;
}
}
return status;
}
TFNodeRegistrar g_tfStrideSliceParser("StridedSlice", new TFStrideSliceParser());

@ -71,8 +71,8 @@ const BaseRef BiDirectionTfGruCellFusion::DefinePattern() const {
auto fw_shape =
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Shape)), transpose_input_});
auto fw_stride =
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_StridedSlice)), fw_shape});
auto fw_stride = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_StridedSlice)),
fw_shape, std::make_shared<SeqVar>()});
auto fw_min =
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Minimum)), fw_stride, fw_max2});
@ -106,8 +106,8 @@ const BaseRef BiDirectionTfGruCellFusion::DefinePattern() const {
bw_reverse_seq, std::make_shared<Var>()});
auto bw_shape =
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Shape)), bw_trans});
auto bw_stride =
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_StridedSlice)), bw_shape});
auto bw_stride = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_StridedSlice)),
bw_shape, std::make_shared<SeqVar>()});
auto bw_min =
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Minimum)), bw_stride, bw_max2});
auto bw_reserve =

Loading…
Cancel
Save