|
|
|
@ -55,7 +55,8 @@ TransferSession::TransferSession(const char *model_buf_backbone, size_t size_bac
|
|
|
|
|
|
|
|
|
|
std::vector<tensor::MSTensor *> TransferSession::GetInputs() const { return combined_inputs_; }
|
|
|
|
|
|
|
|
|
|
bool TransferSession::CompileFormatTransform(tensor::MSTensor *out, tensor::MSTensor *in, int *mask) {
|
|
|
|
|
bool TransferSession::CompileFormatTransform(tensor::MSTensor *out, tensor::MSTensor *in, int *mask, size_t mask_len) {
|
|
|
|
|
MS_ASSERT(out->shape().size() == mask_len);
|
|
|
|
|
for (std::size_t dim = 0; dim != out->shape().size(); ++dim) {
|
|
|
|
|
if (in->shape().at(mask[dim]) != out->shape().at(dim)) {
|
|
|
|
|
return false;
|
|
|
|
@ -85,7 +86,7 @@ int TransferSession::CompileTransferGraph() {
|
|
|
|
|
}
|
|
|
|
|
if (match == false && input->shape().size() == 4) {
|
|
|
|
|
int nchw2nhwc_mask[4] = {0, 3, 1, 2};
|
|
|
|
|
nchw2nhwc_ = CompileFormatTransform(output, input, nchw2nhwc_mask);
|
|
|
|
|
nchw2nhwc_ = CompileFormatTransform(output, input, nchw2nhwc_mask, 4);
|
|
|
|
|
match = nchw2nhwc_;
|
|
|
|
|
}
|
|
|
|
|
if (true == match) {
|
|
|
|
|