bug fix for 3d format

pull/13002/head
liubuyu 4 years ago
parent 659b912f6d
commit 62aa7d0e87

@ -241,8 +241,8 @@ bool TbeKernelSelect::IsShapeMatchFormat(const std::vector<size_t> &shape, const
return true;
}
// not support format:
// 1 NCDHW with shape size != 5
if (format == kOpFormat_NCDHW && shape.size() != kShape5dDims) {
// 1 3d formats with shape size > 5
if (k3DFormatSet.find(format) != k3DFormatSet.end() && shape.size() > kShape5dDims) {
return false;
}
return true;

@ -516,7 +516,8 @@ const std::set<TypeId> kFloatDataTypeSet = {kNumberTypeFloat16, kNumberTypeFloat
const std::set<std::string> kComputeDepend = {kUniqueOpName, kComputeAccidentalHitsOpName, kSubAndFilterOpName,
kPadAndShiftOpName, kCTCGreedyDecoderOpName};
const std::set<std::string> k3DFormatSet = {kOpFormat_NCDHW, kOpFormat_NDC1HWC0, kOpFormat_FRACTAL_Z_3D};
const std::set<std::string> k3DFormatSet = {kOpFormat_NCDHW, kOpFormat_NDC1HWC0, kOpFormat_FRACTAL_Z_3D,
kOpFormat_NDHWC, kOpFormat_DHWCN, kOpFormat_DHWNC};
const std::set<std::string> DynamicShapeConstInputToAttr = {
kCastOpName, kExpandDimsOpName, kReshapeOpName, kEmbeddingLookupOpName, kTransposeOpName, kReduceSumOpName,

Loading…
Cancel
Save