|
|
@ -30,6 +30,7 @@ namespace mindspore {
|
|
|
|
namespace device {
|
|
|
|
namespace device {
|
|
|
|
namespace gpu {
|
|
|
|
namespace gpu {
|
|
|
|
// map<opName, (inputFormatPosition, outputFormatPosition)>, used for getting the insert position of format transform.
|
|
|
|
// map<opName, (inputFormatPosition, outputFormatPosition)>, used for getting the insert position of format transform.
|
|
|
|
|
|
|
|
// If input position is empty, then insert all the input positions, because the input numbers of this op are variable.
|
|
|
|
static std::map<std::string, std::pair<std::vector<size_t>, std::vector<size_t>>> kKernelFormatPositionMap = {
|
|
|
|
static std::map<std::string, std::pair<std::vector<size_t>, std::vector<size_t>>> kKernelFormatPositionMap = {
|
|
|
|
{prim::kPrimConv2D->name(), {{0, 1}, {0}}},
|
|
|
|
{prim::kPrimConv2D->name(), {{0, 1}, {0}}},
|
|
|
|
{prim::kPrimConv2DBackpropInput->name(), {{0, 1}, {0}}},
|
|
|
|
{prim::kPrimConv2DBackpropInput->name(), {{0, 1}, {0}}},
|
|
|
@ -47,6 +48,8 @@ static std::map<std::string, std::pair<std::vector<size_t>, std::vector<size_t>>
|
|
|
|
{kFusedBatchNormGradEx, {{0, 1}, {0}}},
|
|
|
|
{kFusedBatchNormGradEx, {{0, 1}, {0}}},
|
|
|
|
{kFusedBatchNormGradExWithActivation, {{0, 1, 7}, {0}}},
|
|
|
|
{kFusedBatchNormGradExWithActivation, {{0, 1, 7}, {0}}},
|
|
|
|
{kFusedBatchNormGradExWithAddAndActivation, {{0, 1, 7}, {0, 3}}},
|
|
|
|
{kFusedBatchNormGradExWithAddAndActivation, {{0, 1, 7}, {0, 3}}},
|
|
|
|
|
|
|
|
{prim::kPrimConcat->name(), {{}, {0}}},
|
|
|
|
|
|
|
|
{prim::kPrimAddN->name(), {{}, {0}}},
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
void SetKernelInfo(const CNodePtr &apply_kernel_ptr);
|
|
|
|
void SetKernelInfo(const CNodePtr &apply_kernel_ptr);
|
|
|
|