|
|
|
@ -22,6 +22,8 @@
|
|
|
|
|
#include "graph/preprocess/multi_batch_options.h"
|
|
|
|
|
#include "graph/utils/node_utils.h"
|
|
|
|
|
#include "graph/utils/op_desc_utils.h"
|
|
|
|
|
#include "graph/utils/tensor_utils.h"
|
|
|
|
|
#include "graph/utils/type_utils.h"
|
|
|
|
|
#include "register/op_registry.h"
|
|
|
|
|
|
|
|
|
|
namespace ge {
|
|
|
|
@ -478,8 +480,28 @@ Status MultiBatchClonePass::SetMaxShapeToData(const NodePtr &data) {
|
|
|
|
|
if (std::all_of(dims.begin(), dims.end(), [](int64_t val) { return val >= 0; })) {
|
|
|
|
|
return SUCCESS;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
(void)AttrUtils::SetListInt(data->GetOpDesc(), ATTR_MBATCH_ORIGIN_INPUT_DIMS, data_shape.GetDims());
|
|
|
|
|
|
|
|
|
|
GeTensorDesc tensor(NodeUtils::GetOutputDesc(*data, kDataOutIndex));
|
|
|
|
|
std::vector<std::string> input_dims_str;
|
|
|
|
|
for (size_t i = 0; i < batch_shapes_.size(); ++i) {
|
|
|
|
|
auto shape = data_shape;
|
|
|
|
|
auto ret = multibatch::CalcShape(data_to_dynamic_info_.at(data_name).at(i), shape);
|
|
|
|
|
if (ret != SUCCESS) {
|
|
|
|
|
GELOGE(ret, "Failed to calculate the shape for data node %s, the shape may not match", data->GetName().c_str());
|
|
|
|
|
return ret;
|
|
|
|
|
}
|
|
|
|
|
tensor.SetShape(shape);
|
|
|
|
|
int64_t tensor_size = 0;
|
|
|
|
|
(void)TensorUtils::GetTensorSizeInBytes(tensor, tensor_size);
|
|
|
|
|
string input_str = TypeUtils::FormatToSerialString(tensor.GetFormat()) + ":" +
|
|
|
|
|
TypeUtils::DataTypeToSerialString(tensor.GetDataType()) + ":" + data->GetName() + ":" +
|
|
|
|
|
std::to_string(tensor_size) + ":" + std::to_string(tensor.GetShape().GetDimNum()) + ":" +
|
|
|
|
|
formats::JoinToString(tensor.GetShape().GetDims());
|
|
|
|
|
input_dims_str.emplace_back(input_str);
|
|
|
|
|
}
|
|
|
|
|
(void)AttrUtils::SetListStr(data->GetOpDesc(), "_all_origin_gears_inputs", input_dims_str);
|
|
|
|
|
|
|
|
|
|
size_t max_shape_index = 0;
|
|
|
|
|
int64_t max_size = 0;
|
|
|
|
|
for (size_t i = 0; i < batch_shapes_.size(); ++i) {
|
|
|
|
|