!1209 add format transformation functions

Merge pull request !1209 from liubuyu/master
pull/1209/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 828d0b124e

File diff suppressed because it is too large Load Diff

@ -61,6 +61,7 @@ bool TransFormat(const FormatArgs &args, void *result);
bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result);
// host to device
bool NchwTo4D(const FormatArgs &args, void *result);
bool NchwToFracZ(const FormatArgs &args, void *result);
bool NchwToFracNz(const FormatArgs &args, void *result);
bool NchwToNc1hwc0(const FormatArgs &args, void *result);
@ -68,6 +69,7 @@ bool NchwToFracZc04(const FormatArgs &args, void *result);
bool NchwToNc1hwc04(const FormatArgs &args, void *result);
bool NchwToC1hwncoc0(const FormatArgs &args, void *result);
// device to host
bool ToNchw(const FormatArgs &args, void *result);
bool FracZToNchw(const FormatArgs &args, void *result);
bool FracNzToNchw(const FormatArgs &args, void *result);
bool Nc1hwc0ToNchw(const FormatArgs &args, void *result);

@ -16,6 +16,7 @@
#include "device/ascend/ascend_device_address.h"
#include <memory>
#include <vector>
#include <set>
#include <algorithm>
#include "runtime/mem.h"
#include "device/kernel_runtime_manager.h"
@ -34,6 +35,10 @@ namespace device {
namespace ascend {
const int FLOAT_LEN = sizeof(float);
const int FLOAT16_LEN = 2; // sizeof(float16);
const std::set<std::string> kOpNeedTransFormat = {kOpFormat_NHWC, kOpFormat_HWCN, kOpFormat_NC1HWC0,
kOpFormat_FRAC_Z, kOpFormat_C1HWNCoC0, kOpFormat_FRAC_NZ,
kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04};
void SyncMemory(void *dst, const void *src, uint64_t size, rtMemcpyKind_t kind) {
auto ret_rt_memcpy = rtMemcpy(dst, size, src, size, kind);
if (ret_rt_memcpy != RT_ERROR_NONE) {
@ -97,7 +102,7 @@ bool AscendDeviceAddress::SyncDeviceToHost(const std::vector<int> &shape, size_t
if (host_shape.empty()) {
host_shape.emplace_back(1);
}
if (format_ == kOpFormat_NCHW || format_ == kOpFormat_DEFAULT) {
if (format_ == kOpFormat_NCHW || format_ == kOpFormat_DEFAULT || format_ == kOpFormat_NDHWC) {
if (type_id_ == type) {
SyncMemory(host_ptr, ptr_, size, RT_MEMCPY_DEVICE_TO_HOST);
sync_ok = true;
@ -115,9 +120,11 @@ bool AscendDeviceAddress::SyncDeviceToHost(const std::vector<int> &shape, size_t
}
}
} else {
auto iter = kNeedTransFormatSet.find(format_);
if (iter != kNeedTransFormatSet.end()) {
auto iter = kOpNeedTransFormat.find(format_);
if (iter != kOpNeedTransFormat.end()) {
sync_ok = SyncDeviceToHostAndConvertFormat(shape, size, type, host_ptr);
} else {
MS_LOG(INFO) << "Can not find format transfer for :" << format_;
}
}
if (!sync_ok) {
@ -141,7 +148,7 @@ bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormat(const std::vector<int
if (host_shape.empty()) {
host_shape.emplace_back(1);
}
if (format_ == kOpFormat_FRAC_NZ) {
if (format_ == kOpFormat_FRAC_NZ || format_ == kOpFormat_NDHWC) {
device_shape = trans::TransShapeToDevice(host_shape, format_);
} else {
host_shape = trans::PaddingShapeTo4d(host_shape);
@ -185,7 +192,7 @@ bool AscendDeviceAddress::SyncHostToDevice(const std::vector<int> &shape, size_t
if (host_shape.empty()) {
host_shape.emplace_back(1);
}
if (format_ == kOpFormat_NCHW || format_ == kOpFormat_DEFAULT) {
if (format_ == kOpFormat_NCHW || format_ == kOpFormat_DEFAULT || format_ == kOpFormat_NDHWC) {
if (type_id_ == type) {
SyncMemory(ptr_, host_ptr, size_, RT_MEMCPY_HOST_TO_DEVICE);
sync_ok = true;
@ -203,9 +210,11 @@ bool AscendDeviceAddress::SyncHostToDevice(const std::vector<int> &shape, size_t
SyncMemory(ptr_, host_tmp.data(), size_, RT_MEMCPY_HOST_TO_DEVICE);
}
} else {
auto iter = kNeedTransFormatSet.find(format_);
if (iter != kNeedTransFormatSet.end()) {
auto iter = kOpNeedTransFormat.find(format_);
if (iter != kOpNeedTransFormat.end()) {
sync_ok = ConvertFormatAndSyncHostToDevice(shape, size, type, host_ptr);
} else {
MS_LOG(INFO) << "Can not find format transfer for :" << format_;
}
}
if (!sync_ok) {
@ -227,7 +236,7 @@ bool AscendDeviceAddress::ConvertFormatAndSyncHostToDevice(const std::vector<int
host_shape.emplace_back(1);
}
std::vector<size_t> device_shape;
if (format_ == kOpFormat_FRAC_NZ) {
if (format_ == kOpFormat_FRAC_NZ || format_ == kOpFormat_NDHWC) {
device_shape = trans::TransShapeToDevice(host_shape, format_);
} else {
host_shape = trans::PaddingShapeTo4d(host_shape);

@ -0,0 +1,113 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include "common/common_test.h"
#include "common/trans.h"
#include "utils/utils.h"
using namespace std;
namespace mindspore {
namespace trans {
// Test fixture for the host<->device format transformation helpers exercised
// below via trans::TransFormat / trans::TransFormatFromDeviceToHost.
// No per-test state is needed, so SetUp/TearDown are intentionally empty.
class FormatTransTest : public UT::Common {
 public:
  FormatTransTest() = default;
  void SetUp() override {}
  void TearDown() override {}
};
// NCHW -> HWCN (host-to-device direction): transforms a 2x2x2x2 tensor of
// raw float16 bit patterns and checks every element of the device buffer
// against the expected permuted layout. The numeric values are opaque
// payloads — only their ordering after the transpose matters.
TEST_F(FormatTransTest, nchw_to_hwcn) {
  uint16_t data[2 * 2 * 2 * 2] = {12581, 14220, 14937, 14302,
                                  15004, 14951, 14694, 14564,
                                  14069, 14554, 10507, 14787,
                                  13016, 15263, 14872, 10838};
  uint16_t res[2 * 2 * 2 * 2] = {12581, 14069, 15004, 13016,
                                 14220, 14554, 14951, 15263,
                                 14937, 10507, 14694, 14872,
                                 14302, 14787, 14564, 10838};
  // Derive the device buffer size from the input array instead of
  // hard-coding 32 bytes; keeps the test correct if the fixture data grows.
  const size_t device_size = sizeof(data);
  auto trans_tmp = std::vector<uint8_t>(device_size);
  FormatArgs format_args{data, device_size, kOpFormat_NCHW, kOpFormat_HWCN,
                         {2, 2, 2, 2}, {2, 2, 2, 2}, kNumberTypeFloat16};
  EXPECT_EQ(trans::TransFormat(format_args, trans_tmp.data()), true);
  for (size_t i = 0; i < sizeof(res) / sizeof(res[0]); i++) {
    EXPECT_EQ((reinterpret_cast<uint16_t *>(trans_tmp.data()))[i], res[i]);
  }
}
// HWCN -> NCHW (device-to-host direction): the inverse permutation of the
// nchw_to_hwcn data set, run through TransFormatFromDeviceToHost.
TEST_F(FormatTransTest, hwcn_to_nchw) {
  uint16_t device_data[2 * 2 * 2 * 2] = {12581, 14069, 15004, 13016,
                                         14220, 14554, 14951, 15263,
                                         14937, 10507, 14694, 14872,
                                         14302, 14787, 14564, 10838};
  uint16_t expected[2 * 2 * 2 * 2] = {12581, 14220, 14937, 14302,
                                      15004, 14951, 14694, 14564,
                                      14069, 14554, 10507, 14787,
                                      13016, 15263, 14872, 10838};
  size_t device_size = 32;
  std::vector<uint8_t> host_buf(device_size);
  FormatArgs format_args{device_data, device_size, kOpFormat_NCHW, kOpFormat_HWCN,
                         {2, 2, 2, 2}, {2, 2, 2, 2}, kNumberTypeFloat16};
  EXPECT_EQ(trans::TransFormatFromDeviceToHost(format_args, host_buf.data()), true);
  // Compare the converted host buffer element-by-element with the oracle.
  const auto *actual = reinterpret_cast<uint16_t *>(host_buf.data());
  const size_t num_elems = sizeof(expected) / sizeof(expected[0]);
  for (size_t idx = 0; idx < num_elems; ++idx) {
    EXPECT_EQ(actual[idx], expected[idx]);
  }
}
// NCHW -> NHWC (host-to-device direction): permutes a 2x2x2x2 tensor of raw
// float16 bit patterns and verifies the device-side ordering.
TEST_F(FormatTransTest, nchw_to_nhwc) {
  uint16_t host_data[2 * 2 * 2 * 2] = {11750, 13778, 15007, 15321,
                                       15163, 13446, 15063, 14467,
                                       15056, 13284, 15219, 14797,
                                       12684, 14288, 14855, 14799};
  uint16_t expected[2 * 2 * 2 * 2] = {11750, 15163, 13778, 13446,
                                      15007, 15063, 15321, 14467,
                                      15056, 12684, 13284, 14288,
                                      15219, 14855, 14797, 14799};
  size_t device_size = 32;
  std::vector<uint8_t> device_buf(device_size);
  FormatArgs format_args{host_data, device_size, kOpFormat_NCHW, kOpFormat_NHWC,
                         {2, 2, 2, 2}, {2, 2, 2, 2}, kNumberTypeFloat16};
  EXPECT_EQ(trans::TransFormat(format_args, device_buf.data()), true);
  // Element-wise check of the transformed buffer against the oracle layout.
  const auto *actual = reinterpret_cast<uint16_t *>(device_buf.data());
  const size_t num_elems = sizeof(expected) / sizeof(expected[0]);
  for (size_t idx = 0; idx < num_elems; ++idx) {
    EXPECT_EQ(actual[idx], expected[idx]);
  }
}
// NHWC -> NCHW (device-to-host direction): the inverse permutation of the
// nchw_to_nhwc data set, run through TransFormatFromDeviceToHost. The values
// are opaque float16 bit patterns; only their post-transpose ordering matters.
TEST_F(FormatTransTest, nhwc_to_nchw) {
  uint16_t data[2 * 2 * 2 * 2] = {11750, 15163, 13778, 13446,
                                  15007, 15063, 15321, 14467,
                                  15056, 12684, 13284, 14288,
                                  15219, 14855, 14797, 14799};
  uint16_t res[2 * 2 * 2 * 2] = {11750, 13778, 15007, 15321,
                                 15163, 13446, 15063, 14467,
                                 15056, 13284, 15219, 14797,
                                 12684, 14288, 14855, 14799};
  // Derive the device buffer size from the input array instead of
  // hard-coding 32 bytes; keeps the test correct if the fixture data grows.
  const size_t device_size = sizeof(data);
  auto trans_tmp = std::vector<uint8_t>(device_size);
  FormatArgs format_args{data, device_size, kOpFormat_NCHW, kOpFormat_NHWC,
                         {2, 2, 2, 2}, {2, 2, 2, 2}, kNumberTypeFloat16};
  EXPECT_EQ(trans::TransFormatFromDeviceToHost(format_args, trans_tmp.data()), true);
  for (size_t i = 0; i < sizeof(res) / sizeof(res[0]); i++) {
    EXPECT_EQ((reinterpret_cast<uint16_t *>(trans_tmp.data()))[i], res[i]);
  }
}
} // namespace trans
} // namespace mindspore
Loading…
Cancel
Save