|
|
|
@ -281,9 +281,11 @@ void CastFrom(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *out
|
|
|
|
|
case DataType::DE_UINT64:
|
|
|
|
|
Cast<T, uint64_t>(input, output);
|
|
|
|
|
break;
|
|
|
|
|
#ifndef ENABLE_MD_LITE_X86_64
|
|
|
|
|
case DataType::DE_FLOAT16:
|
|
|
|
|
Cast<T, float16>(input, output);
|
|
|
|
|
break;
|
|
|
|
|
#endif
|
|
|
|
|
case DataType::DE_FLOAT32:
|
|
|
|
|
Cast<T, float>(input, output);
|
|
|
|
|
break;
|
|
|
|
@ -328,9 +330,11 @@ Status TypeCast(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *o
|
|
|
|
|
case DataType::DE_UINT64:
|
|
|
|
|
CastFrom<uint64_t>(input, output);
|
|
|
|
|
break;
|
|
|
|
|
#ifndef ENABLE_MD_LITE_X86_64
|
|
|
|
|
case DataType::DE_FLOAT16:
|
|
|
|
|
CastFrom<float16>(input, output);
|
|
|
|
|
break;
|
|
|
|
|
#endif
|
|
|
|
|
case DataType::DE_FLOAT32:
|
|
|
|
|
CastFrom<float>(input, output);
|
|
|
|
|
break;
|
|
|
|
@ -344,6 +348,7 @@ Status TypeCast(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *o
|
|
|
|
|
return Status::OK();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#ifndef ENABLE_MD_LITE_X86_64
|
|
|
|
|
Status ToFloat16(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
|
|
|
|
|
// initiate new tensor for type cast
|
|
|
|
|
DataType new_type = DataType("float16");
|
|
|
|
@ -367,6 +372,9 @@ Status ToFloat16(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *
|
|
|
|
|
|
|
|
|
|
return Status::OK();
|
|
|
|
|
}
|
|
|
|
|
#else
|
|
|
|
|
Status ToFloat16(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) { return Status::OK(); }
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
Status PadEnd(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> *dst, const std::vector<dsize_t> &pad_shape,
|
|
|
|
|
const std::shared_ptr<Tensor> &pad_val) {
|
|
|
|
@ -410,9 +418,13 @@ Status PadEndNumeric(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor>
|
|
|
|
|
RETURN_IF_NOT_OK((*dst)->Fill<uint8_t>(pad_val));
|
|
|
|
|
} else if (tensor_type == DataType::DE_INT16) {
|
|
|
|
|
RETURN_IF_NOT_OK((*dst)->Fill<int16_t>(pad_val));
|
|
|
|
|
} else if (tensor_type == DataType::DE_FLOAT16) {
|
|
|
|
|
}
|
|
|
|
|
#ifndef ENABLE_MD_LITE_X86_64
|
|
|
|
|
else if (tensor_type == DataType::DE_FLOAT16) { // NOLINT
|
|
|
|
|
RETURN_IF_NOT_OK((*dst)->Fill<float16>(static_cast<float16>(pad_val)));
|
|
|
|
|
} else if (tensor_type == DataType::DE_UINT16) {
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
else if (tensor_type == DataType::DE_UINT16) { // NOLINT
|
|
|
|
|
RETURN_IF_NOT_OK((*dst)->Fill<uint16_t>(pad_val));
|
|
|
|
|
} else if (tensor_type == DataType::DE_INT32) {
|
|
|
|
|
RETURN_IF_NOT_OK((*dst)->Fill<int32_t>(pad_val));
|
|
|
|
@ -570,9 +582,11 @@ Status Mask(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *outpu
|
|
|
|
|
case DataType::DE_INT64:
|
|
|
|
|
RETURN_IF_NOT_OK(MaskHelper<int64_t>(input, *output, casted_value, op));
|
|
|
|
|
break;
|
|
|
|
|
#ifndef ENABLE_MD_LITE_X86_64
|
|
|
|
|
case DataType::DE_FLOAT16:
|
|
|
|
|
RETURN_IF_NOT_OK(MaskHelper<float16>(input, *output, casted_value, op));
|
|
|
|
|
break;
|
|
|
|
|
#endif
|
|
|
|
|
case DataType::DE_FLOAT32:
|
|
|
|
|
RETURN_IF_NOT_OK(MaskHelper<float>(input, *output, casted_value, op));
|
|
|
|
|
break;
|
|
|
|
@ -732,6 +746,7 @@ struct UniqueOpHashMap<float16> {
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
#else
|
|
|
|
|
#ifndef ENABLE_MD_LITE_X86_64
|
|
|
|
|
struct gn_hash {
|
|
|
|
|
size_t operator()(const float16 &f) const { return static_cast<std::size_t>(f); }
|
|
|
|
|
};
|
|
|
|
@ -740,7 +755,7 @@ template <>
|
|
|
|
|
struct UniqueOpHashMap<float16> {
|
|
|
|
|
using map_type = std::unordered_map<float16, int32_t, gn_hash>;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
#endif
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
@ -809,9 +824,13 @@ Status Unique(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *out
|
|
|
|
|
RETURN_IF_NOT_OK(UniqueHelper<uint16_t>(input, output, output_idx, output_cnt));
|
|
|
|
|
} else if (input->type() == DataType::DE_UINT8) {
|
|
|
|
|
RETURN_IF_NOT_OK(UniqueHelper<uint8_t>(input, output, output_idx, output_cnt));
|
|
|
|
|
} else if (input->type() == DataType::DE_FLOAT16) {
|
|
|
|
|
}
|
|
|
|
|
#ifndef ENABLE_MD_LITE_X86_64
|
|
|
|
|
else if (input->type() == DataType::DE_FLOAT16) { // NOLINT
|
|
|
|
|
RETURN_IF_NOT_OK(UniqueHelper<float16>(input, output, output_idx, output_cnt));
|
|
|
|
|
} else if (input->type() == DataType::DE_FLOAT32) {
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
else if (input->type() == DataType::DE_FLOAT32) { // NOLINT
|
|
|
|
|
RETURN_IF_NOT_OK(UniqueHelper<float>(input, output, output_idx, output_cnt));
|
|
|
|
|
} else if (input->type() == DataType::DE_FLOAT64) {
|
|
|
|
|
RETURN_IF_NOT_OK(UniqueHelper<double>(input, output, output_idx, output_cnt));
|
|
|
|
|