!3135 GPU cast support more type

Merge pull request !3135 from VectorSL/cast2
pull/3135/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 11732f0ea2

@ -18,6 +18,7 @@
#include <algorithm>
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/optimizer/common/helper.h"
#include "backend/kernel_compiler/common_utils.h"
namespace mindspore {
namespace kernel {
@ -75,15 +76,7 @@ void SetAkgAttrsForCast(const AnfNodePtr &anf_node) {
std::string dst_type;
TypeId output_type = AnfAlgo::GetOutputDeviceDataType(anf_node, 0);
if (output_type == kFloat32->type_id()) {
dst_type = "float32";
} else if (output_type == kFloat16->type_id()) {
dst_type = "float16";
} else if (output_type == kInt32->type_id()) {
dst_type = "int32";
} else {
MS_LOG(WARNING) << "Unknown cast_to type: " << TypeIdToType(output_type)->ToString();
}
dst_type = TypeId2String(output_type);
AnfAlgo::SetNodeAttr("dst_type", MakeValue(dst_type), anf_node);
}

@ -21,10 +21,39 @@ cast_op_info = AkgGpuRegOp("Cast") \
.output(0, "output") \
.attr("dst_type", "required", "str") \
.dtype_format(DataType.F16_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.I32_Default) \
.dtype_format(DataType.F16_Default, DataType.I32_Default) \
.dtype_format(DataType.F16_Default, DataType.F64_Default) \
.dtype_format(DataType.I32_Default, DataType.F16_Default) \
.dtype_format(DataType.I32_Default, DataType.F32_Default) \
.dtype_format(DataType.I32_Default, DataType.I8_Default) \
.dtype_format(DataType.I32_Default, DataType.U8_Default) \
.dtype_format(DataType.I32_Default, DataType.BOOL_Default) \
.dtype_format(DataType.I8_Default, DataType.F64_Default) \
.dtype_format(DataType.I8_Default, DataType.F32_Default) \
.dtype_format(DataType.I8_Default, DataType.F16_Default) \
.dtype_format(DataType.I8_Default, DataType.I32_Default) \
.dtype_format(DataType.I8_Default, DataType.I16_Default) \
.dtype_format(DataType.I8_Default, DataType.I64_Default) \
.dtype_format(DataType.BOOL_Default, DataType.F32_Default) \
.dtype_format(DataType.BOOL_Default, DataType.F16_Default) \
.dtype_format(DataType.BOOL_Default, DataType.F64_Default) \
.dtype_format(DataType.BOOL_Default, DataType.I8_Default) \
.dtype_format(DataType.BOOL_Default, DataType.I16_Default) \
.dtype_format(DataType.BOOL_Default, DataType.I32_Default) \
.dtype_format(DataType.BOOL_Default, DataType.I64_Default) \
.dtype_format(DataType.U8_Default, DataType.F32_Default) \
.dtype_format(DataType.U8_Default, DataType.F16_Default) \
.dtype_format(DataType.U8_Default, DataType.I32_Default) \
.dtype_format(DataType.I16_Default, DataType.F64_Default) \
.dtype_format(DataType.I16_Default, DataType.F32_Default) \
.dtype_format(DataType.I16_Default, DataType.F16_Default) \
.dtype_format(DataType.I16_Default, DataType.I32_Default) \
.dtype_format(DataType.I16_Default, DataType.I64_Default) \
.dtype_format(DataType.I64_Default, DataType.F64_Default) \
.dtype_format(DataType.I16_Default, DataType.F32_Default) \
.dtype_format(DataType.I16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.I32_Default) \
.dtype_format(DataType.F32_Default, DataType.F16_Default) \
.get_op_info()

File diff suppressed because it is too large Load Diff
Loading…
Cancel
Save