!13582 [MD][BUG] fix MD lite train error

From: @xulei2020
pull/13582/MERGE
mindspore-ci-bot, committed by Gitee
commit bfcf4562ad
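The change addresses the MD lite error by compiling out the float16-dependent code paths whenever MindData lite is built for x86_64: the CMake hunk near the end defines ENABLE_MD_LITE_X86_64 for lite builds on platforms that are neither ARM32 nor ARM64, and every float16 template specialization, switch case, and branch is wrapped in a matching guard. A minimal sketch of the recurring pattern, taken from the first hunk below (illustrative only, not the complete change):

// Compiled only when the build provides a float16 type; on x86_64 lite
// builds ENABLE_MD_LITE_X86_64 is defined and the specialization is dropped.
#ifndef ENABLE_MD_LITE_X86_64
template <>
inline DataType DataType::FromCType<float16>() {
  return DataType(DataType::DE_FLOAT16);
}
#endif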

@@ -257,10 +257,12 @@ inline DataType DataType::FromCType<float>() {
return DataType(DataType::DE_FLOAT32);
}
#ifndef ENABLE_MD_LITE_X86_64
template <>
inline DataType DataType::FromCType<float16>() {
return DataType(DataType::DE_FLOAT16);
}
#endif
template <>
inline DataType DataType::FromCType<int64_t>() {
@@ -327,10 +329,12 @@ inline bool DataType::IsLooselyCompatible<float>() const {
return type_ == DataType::DE_FLOAT32;
}
#ifndef ENABLE_MD_LITE_X86_64
template <>
inline bool DataType::IsLooselyCompatible<float16>() const {
return type_ == DataType::DE_FLOAT16;
}
#endif
template <>
inline bool DataType::IsLooselyCompatible<int64_t>() const {

@@ -396,9 +396,9 @@ void Tensor::PrintItemAt(const std::vector<dsize_t> &index, std::ostream &out) c
CASE_PRINT(DataType::DE_INT64, int64_t)
CASE_PRINT(DataType::DE_UINT64, uint64_t)
#ifndef ENABLE_MD_LITE_X86_64
CASE_PRINT(DataType::DE_FLOAT16, float16)
#endif
CASE_PRINT(DataType::DE_FLOAT32, float)
CASE_PRINT(DataType::DE_FLOAT64, double)
@@ -825,12 +825,14 @@ Status Tensor::GetFloatAt(T *o, const std::vector<dsize_t> &index) const {
RETURN_STATUS_UNEXPECTED(err);
}
switch (type_.value()) {
#ifndef ENABLE_MD_LITE_X86_64
case DataType::DE_FLOAT16: {
float16 *ptr = nullptr;
RETURN_IF_NOT_OK(GetItemPtr<float16>(&ptr, index));
*o = static_cast<T>(*ptr);
break;
}
#endif
case DataType::DE_FLOAT32: {
float *ptr = nullptr;
RETURN_IF_NOT_OK(GetItemPtr<float>(&ptr, index));

@@ -281,9 +281,11 @@ void CastFrom(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *out
case DataType::DE_UINT64:
Cast<T, uint64_t>(input, output);
break;
#ifndef ENABLE_MD_LITE_X86_64
case DataType::DE_FLOAT16:
Cast<T, float16>(input, output);
break;
#endif
case DataType::DE_FLOAT32:
Cast<T, float>(input, output);
break;
@@ -328,9 +330,11 @@ Status TypeCast(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *o
case DataType::DE_UINT64:
CastFrom<uint64_t>(input, output);
break;
#ifndef ENABLE_MD_LITE_X86_64
case DataType::DE_FLOAT16:
CastFrom<float16>(input, output);
break;
#endif
case DataType::DE_FLOAT32:
CastFrom<float>(input, output);
break;
@@ -344,6 +348,7 @@ Status TypeCast(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *o
return Status::OK();
}
#ifndef ENABLE_MD_LITE_X86_64
Status ToFloat16(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
// initiate new tensor for type cast
DataType new_type = DataType("float16");
@@ -367,6 +372,9 @@ Status ToFloat16(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *
return Status::OK();
}
#else
Status ToFloat16(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) { return Status::OK(); }
#endif
Status PadEnd(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> *dst, const std::vector<dsize_t> &pad_shape,
const std::shared_ptr<Tensor> &pad_val) {
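In the hunk above, an x86_64 lite build gets a ToFloat16 stub that returns Status::OK() without writing to *output. A hypothetical caller-side sketch (the variable names and the pass-through fallback are assumptions, not part of the patch):

std::shared_ptr<Tensor> out;
#ifndef ENABLE_MD_LITE_X86_64
RETURN_IF_NOT_OK(ToFloat16(in, &out));  // real float16 conversion path
#else
out = in;  // float16 is compiled out on x86_64 lite; keep the input tensor unchanged
#endif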
@@ -410,9 +418,13 @@ Status PadEndNumeric(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor>
RETURN_IF_NOT_OK((*dst)->Fill<uint8_t>(pad_val));
} else if (tensor_type == DataType::DE_INT16) {
RETURN_IF_NOT_OK((*dst)->Fill<int16_t>(pad_val));
} else if (tensor_type == DataType::DE_FLOAT16) {
}
#ifndef ENABLE_MD_LITE_X86_64
else if (tensor_type == DataType::DE_FLOAT16) { // NOLINT
RETURN_IF_NOT_OK((*dst)->Fill<float16>(static_cast<float16>(pad_val)));
} else if (tensor_type == DataType::DE_UINT16) {
}
#endif
else if (tensor_type == DataType::DE_UINT16) { // NOLINT
RETURN_IF_NOT_OK((*dst)->Fill<uint16_t>(pad_val));
} else if (tensor_type == DataType::DE_INT32) {
RETURN_IF_NOT_OK((*dst)->Fill<int32_t>(pad_val));
@@ -570,9 +582,11 @@ Status Mask(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *outpu
case DataType::DE_INT64:
RETURN_IF_NOT_OK(MaskHelper<int64_t>(input, *output, casted_value, op));
break;
#ifndef ENABLE_MD_LITE_X86_64
case DataType::DE_FLOAT16:
RETURN_IF_NOT_OK(MaskHelper<float16>(input, *output, casted_value, op));
break;
#endif
case DataType::DE_FLOAT32:
RETURN_IF_NOT_OK(MaskHelper<float>(input, *output, casted_value, op));
break;
@@ -732,6 +746,7 @@ struct UniqueOpHashMap<float16> {
};
#else
#ifndef ENABLE_MD_LITE_X86_64
struct gn_hash {
size_t operator()(const float16 &f) const { return static_cast<std::size_t>(f); }
};
@ -740,7 +755,7 @@ template <>
struct UniqueOpHashMap<float16> {
using map_type = std::unordered_map<float16, int32_t, gn_hash>;
};
#endif
#endif
template <>
@@ -809,9 +824,13 @@ Status Unique(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *out
RETURN_IF_NOT_OK(UniqueHelper<uint16_t>(input, output, output_idx, output_cnt));
} else if (input->type() == DataType::DE_UINT8) {
RETURN_IF_NOT_OK(UniqueHelper<uint8_t>(input, output, output_idx, output_cnt));
} else if (input->type() == DataType::DE_FLOAT16) {
}
#ifndef ENABLE_MD_LITE_X86_64
else if (input->type() == DataType::DE_FLOAT16) { // NOLINT
RETURN_IF_NOT_OK(UniqueHelper<float16>(input, output, output_idx, output_cnt));
} else if (input->type() == DataType::DE_FLOAT32) {
}
#endif
else if (input->type() == DataType::DE_FLOAT32) { // NOLINT
RETURN_IF_NOT_OK(UniqueHelper<float>(input, output, output_idx, output_cnt));
} else if (input->type() == DataType::DE_FLOAT64) {
RETURN_IF_NOT_OK(UniqueHelper<double>(input, output, output_idx, output_cnt));

@@ -23,6 +23,7 @@
using float16 = float16_t;
inline float half_to_float(float16 h) { return static_cast<float>(h); }
#else
#ifndef ENABLE_MD_LITE_X86_64
#include <functional>
#include "Eigen/Core"
@@ -30,4 +31,5 @@ using float16 = Eigen::half;
using HalfToFloat = std::function<float(float16)>;
const inline HalfToFloat half_to_float = Eigen::half_impl::half_to_float;
#endif
#endif
#endif // MINDSPORE_CORE_BASE_FLOAT16_H_
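Since the two nested guards are easy to misread in the fragmented hunks, this is roughly how float16.h nests after the patch; the opening #if that distinguishes ARM builds sits above the first hunk, so its exact condition is an assumption here:

#if defined(ENABLE_ARM)  // assumed ARM-build test; the real condition is above the hunk
using float16 = float16_t;  // native ARM half type
inline float half_to_float(float16 h) { return static_cast<float>(h); }
#else
#ifndef ENABLE_MD_LITE_X86_64  // new guard: no Eigen-backed float16 on x86_64 lite
#include <functional>
#include "Eigen/Core"
using float16 = Eigen::half;
using HalfToFloat = std::function<float(float16)>;
const inline HalfToFloat half_to_float = Eigen::half_impl::half_to_float;
#endif  // ENABLE_MD_LITE_X86_64
#endif  // ARM / non-ARM
#endif  // MINDSPORE_CORE_BASE_FLOAT16_H_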

@@ -260,6 +260,9 @@ endif()
if(BUILD_MINDDATA STREQUAL "lite" OR BUILD_MINDDATA STREQUAL "full" OR BUILD_MINDDATA STREQUAL "wrapper")
add_compile_definitions(ENABLE_ANDROID)
if(NOT PLATFORM_ARM32 AND NOT PLATFORM_ARM64)
add_compile_definitions(ENABLE_MD_LITE_X86_64)
endif()
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/minddata)
endif()
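The CMake hunk above defines ENABLE_MD_LITE_X86_64 only when MindData is built in lite/full/wrapper mode on a platform that is neither ARM32 nor ARM64, i.e. the x86_64 host case. A hypothetical compile-time probe to confirm which path a translation unit sees (not part of the patch):

#ifdef ENABLE_MD_LITE_X86_64
#pragma message("MD lite x86_64 build: float16 support is compiled out")
#else
#pragma message("float16 support (float16_t / Eigen::half) is available")
#endif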

@@ -20,11 +20,14 @@
#include <unordered_map>
#include <functional>
#include <string>
#include "Eigen/Core"
#include "ops/fusion/conv2d_fusion.h"
#include "src/common/common.h"
#include "frontend/operator/ops.h"
#include "backend/optimizer/common/helper.h"
using float16 = Eigen::half;
namespace mindspore {
namespace opt {
namespace {
