!8314 Add dynamic shape support to GPU Transpose

From: @TFbunny
Reviewed-by: @robingrosman
Signed-off-by:
pull/8314/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 5caded733e

@ -222,7 +222,8 @@ AbstractBasePtr InferImplReduceScatter(const AnalysisEnginePtr &, const Primitiv
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplSGD(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplTranspose(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplMemCpyAsync(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplSub(const AnalysisEnginePtr &, const PrimitivePtr &primitive,

@ -14,6 +14,7 @@
* limitations under the License.
*/
#include <set>
#include <algorithm>
#include <iterator>
#include "abstract/infer_functions.h"
@ -385,5 +386,35 @@ AbstractBasePtr InferImplZerosLike(const AnalysisEnginePtr &, const PrimitivePtr
return args_spec_list[0]->Broaden();
}
AbstractBasePtr InferImplTranspose(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
const std::string &op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 2);
AbstractTensorPtr input = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
auto perm = CheckArg<AbstractTuple>(op_name, args_spec_list, 1);
auto input_shp = input->shape()->shape();
auto perm_val = perm->BuildValue();
if (perm_val->isa<AnyValue>()) {
MS_LOG(EXCEPTION) << "Perm can't be anything: " << args_spec_list[1]->ToString();
}
auto perm_val_data = perm_val->cast<ValueTuplePtr>()->value();
ShapeVector perm_vec;
(void)std::transform(std::begin(perm_val_data), std::end(perm_val_data), std::back_inserter(perm_vec),
[](const ValuePtr &e) -> int64_t { return GetValue<int64_t>(e); });
ShapeVector result_shp;
std::set<size_t> indices;
for (size_t i = 0; i < perm_vec.size(); i++) {
size_t idx = static_cast<size_t>(perm_vec[i]);
if (indices.find(idx) != indices.end()) {
MS_LOG(EXCEPTION) << "Perm values must be unique";
}
if (idx >= perm_vec.size()) {
MS_LOG(EXCEPTION) << "One value in perm is " << idx << ", not in range [0, " << perm_vec.size() << ")";
}
result_shp.push_back(input_shp[idx]);
indices.insert(idx);
}
return std::make_shared<AbstractTensor>(input->element(), std::make_shared<Shape>(result_shp));
}
} // namespace abstract
} // namespace mindspore

@ -60,6 +60,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
{prim::kPrimRealDiv, {InferImplRealDiv, true}},
{prim::kPrimShape, {InferImplShape, false}},
{prim::kPrimDynamicShape, {InferImplDynamicShape, true}},
{prim::kPrimTranspose, {InferImplTranspose, true}},
// Structure
{prim::kPrimMakeTuple, {InferImplMakeTuple, true}},
{prim::kPrimMakeList, {InferImplMakeList, true}},

@ -589,7 +589,7 @@ class Squeeze(PrimitiveWithInfer):
return x_dtype
class Transpose(PrimitiveWithInfer):
class Transpose(PrimitiveWithCheck):
"""
Permutes the dimensions of input tensor according to input permutation.
@ -621,32 +621,13 @@ class Transpose(PrimitiveWithInfer):
"""Initialize Transpose"""
self.init_prim_io_names(inputs=['x', 'perm'], outputs=['output'])
def __infer__(self, x, perm):
x_shape = x['shape']
p_value = perm['value']
x_type = x['dtype']
validator.check_value_type("p_value", p_value, [tuple], self.name)
validator.check_subclass("x_type", x_type, mstype.tensor, self.name)
if len(x_shape) != len(p_value):
def check_shape(self, x, perm):
validator.check_value_type("perm", perm, [tuple], self.name)
if len(x) != len(perm):
raise ValueError('The dimension of x and perm must be equal.')
tmp = list(p_value)
for i, dim in enumerate(p_value):
validator.check_int(dim, 0, Rel.GE, f'perm[{i}]', self.name)
validator.check_int(dim, len(p_value), Rel.LT, f'perm[{i}]', self.name)
tmp.remove(dim)
if dim in tmp:
raise ValueError('The value of perm is wrong.')
out_shapes = []
for i in p_value:
out_shapes.append(x_shape[i])
out = {'shape': tuple(out_shapes),
'dtype': x['dtype'],
'value': None}
return out
def check_dtype(self, x, perm):
validator.check_subclass("x", x, mstype.tensor, self.name)
class Unique(Primitive):
"""

Loading…
Cancel
Save