|
|
|
@ -72,76 +72,6 @@ STATUS TFStrideSliceParser::Parse(const tensorflow::NodeDef &tf_op,
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
attr->shrinkAxisMask = attr_value.i();
|
|
|
|
|
|
|
|
|
|
// begin
|
|
|
|
|
auto begin_node = GetConstInputNode(tf_node_map, tf_op.input(1));
|
|
|
|
|
if (begin_node == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "Find StridedSlice input begin failed";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
if (!TensorFlowUtils::FindAttrValue(*begin_node, "value", &attr_value)) {
|
|
|
|
|
MS_LOG(ERROR) << "The value attr should be specified";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
auto tensor_proto = attr_value.tensor();
|
|
|
|
|
if (tensor_proto.int_val_size() > 0) {
|
|
|
|
|
for (int i = 0; i < tensor_proto.int_val_size(); ++i) {
|
|
|
|
|
attr->begin.push_back(tensor_proto.int_val(i));
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
auto data_num = tensor_proto.tensor_content().size() / sizeof(int32_t);
|
|
|
|
|
auto data = reinterpret_cast<const int32_t *>(tensor_proto.tensor_content().data());
|
|
|
|
|
for (size_t i = 0; i < data_num; ++i) {
|
|
|
|
|
attr->begin.push_back(data[i]);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// end
|
|
|
|
|
auto end_node = GetConstInputNode(tf_node_map, tf_op.input(2));
|
|
|
|
|
if (end_node == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "Find StridedSlice input end failed";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
if (!TensorFlowUtils::FindAttrValue(*end_node, "value", &attr_value)) {
|
|
|
|
|
MS_LOG(ERROR) << "The value attr should be specified";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
tensor_proto = attr_value.tensor();
|
|
|
|
|
if (tensor_proto.int_val_size() > 0) {
|
|
|
|
|
for (int i = 0; i < tensor_proto.int_val_size(); ++i) {
|
|
|
|
|
attr->end.push_back(tensor_proto.int_val(i));
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
auto data_num = tensor_proto.tensor_content().size() / sizeof(int32_t);
|
|
|
|
|
auto data = reinterpret_cast<const int32_t *>(tensor_proto.tensor_content().data());
|
|
|
|
|
for (size_t i = 0; i < data_num; ++i) {
|
|
|
|
|
attr->end.push_back(data[i]);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// strides
|
|
|
|
|
auto stride_node = GetConstInputNode(tf_node_map, tf_op.input(3));
|
|
|
|
|
if (stride_node == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "Find StridedSlice input strides failed";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
if (!TensorFlowUtils::FindAttrValue(*stride_node, "value", &attr_value)) {
|
|
|
|
|
MS_LOG(ERROR) << "The value attr should be specified";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
tensor_proto = attr_value.tensor();
|
|
|
|
|
if (tensor_proto.int_val_size() > 0) {
|
|
|
|
|
for (int i = 0; i < tensor_proto.int_val_size(); ++i) {
|
|
|
|
|
attr->stride.push_back(tensor_proto.int_val(i));
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
auto data_num = tensor_proto.tensor_content().size() / sizeof(int32_t);
|
|
|
|
|
auto data = reinterpret_cast<const int32_t *>(tensor_proto.tensor_content().data());
|
|
|
|
|
for (size_t i = 0; i < data_num; ++i) {
|
|
|
|
|
attr->stride.push_back(data[i]);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
primitive->value.type = schema::PrimitiveType_StridedSlice;
|
|
|
|
|
primitive->value.value = attr.release();
|
|
|
|
|
*primitiveC = PrimitiveC::Create(primitive.release());
|
|
|
|
@ -151,7 +81,14 @@ STATUS TFStrideSliceParser::Parse(const tensorflow::NodeDef &tf_op,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
*output_size = 1;
|
|
|
|
|
auto status = AddOpInput(tf_op, 0, inputs);
|
|
|
|
|
STATUS status = RET_OK;
|
|
|
|
|
for (int i = 0; i < tf_op.input_size(); i++) {
|
|
|
|
|
status = AddOpInput(tf_op, i, inputs);
|
|
|
|
|
if (status != RET_OK) {
|
|
|
|
|
MS_LOG(ERROR) << "Add Op input failed.";
|
|
|
|
|
return status;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return status;
|
|
|
|
|
}
|
|
|
|
|
TFNodeRegistrar g_tfStrideSliceParser("StridedSlice", new TFStrideSliceParser());
|
|
|
|
|