!8481 Add dynamic shape support to GPU Reshape

From: @TFbunny
Reviewed-by: @robingrosman,@tom__chen
Signed-off-by: @tom__chen
pull/8481/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 8d6c780f93

@ -224,6 +224,8 @@ AbstractBasePtr InferImplSGD(const AnalysisEnginePtr &, const PrimitivePtr &prim
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplTranspose(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplReshape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplMemCpyAsync(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplSub(const AnalysisEnginePtr &, const PrimitivePtr &primitive,

@ -416,5 +416,57 @@ AbstractBasePtr InferImplTranspose(const AnalysisEnginePtr &, const PrimitivePtr
}
return std::make_shared<AbstractTensor>(input->element(), std::make_shared<Shape>(result_shp));
}
AbstractBasePtr InferImplReshape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
const std::string &op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 2);
AbstractTensorPtr input = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
auto reshape = CheckArg<AbstractTuple>(op_name, args_spec_list, 1);
auto input_shp = input->shape()->shape();
auto reshape_val = reshape->BuildValue();
if (reshape_val->isa<AnyValue>()) {
MS_LOG(EXCEPTION) << "Input_shape can't be anything: " << args_spec_list[1]->ToString();
}
auto reshape_val_data = reshape_val->cast<ValueTuplePtr>()->value();
ShapeVector reshape_vec;
(void)std::transform(std::begin(reshape_val_data), std::end(reshape_val_data), std::back_inserter(reshape_vec),
[](const ValuePtr &e) -> int64_t { return GetValue<int64_t>(e); });
ShapeVector result_shp;
auto input_prod = input_shp[0];
int64_t dim_prod = 1;
size_t neg_idx = 0;
for (size_t i = 1; i < input_shp.size(); i++) {
input_prod *= input_shp[i];
}
auto num_neg_one = count(std::begin(reshape_vec), std::end(reshape_vec), -1);
if (num_neg_one > 1) {
MS_LOG(EXCEPTION) << "The shape can only has one -1 at most, but " << num_neg_one;
}
for (size_t i = 0; i < reshape_vec.size(); i++) {
if (reshape_vec[i] == -1) {
neg_idx = i;
result_shp.push_back(-1);
} else {
dim_prod *= reshape_vec[i];
result_shp.push_back(reshape_vec[i]);
}
}
if (dim_prod < 0 || input_prod % dim_prod != 0) {
MS_LOG(EXCEPTION) << "The input_x shape product is " << input_prod << ", input_shape shape product is " << dim_prod
<< ", and this value should be > 0 and should divide product of input_x.";
}
if (num_neg_one == 1) {
int64_t val = static_cast<int64_t>(input_prod) / dim_prod;
dim_prod *= val;
result_shp[neg_idx] = val;
}
if (dim_prod != input_prod) {
MS_LOG(EXCEPTION)
<< "The product of input_x shape should be equal to product of input_shape shape, but input_x shape is "
<< input_prod << ", product of input_shape shape is " << dim_prod;
}
return std::make_shared<AbstractTensor>(input->element(), std::make_shared<Shape>(result_shp));
}
} // namespace abstract
} // namespace mindspore

@ -61,6 +61,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
{prim::kPrimShape, {InferImplShape, false}},
{prim::kPrimDynamicShape, {InferImplDynamicShape, true}},
{prim::kPrimTranspose, {InferImplTranspose, true}},
{prim::kPrimReshape, {InferImplReshape, true}},
// Structure
{prim::kPrimMakeTuple, {InferImplMakeTuple, true}},
{prim::kPrimMakeList, {InferImplMakeList, true}},

Loading…
Cancel
Save