|
|
|
@ -363,9 +363,6 @@ void SetTensorDeviceInfo(const CNodePtr &kernel_node) {
|
|
|
|
|
if (real_input_node->isa<Parameter>() && !AnfAlgo::IsParameterWeight(real_input_node->cast<ParameterPtr>())) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
if (selected_kernel_info->GetInputFormat(input_index) == kOpFormat_FRACTAL_ZN_LSTM) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
// we set special device info of a input tensor.
|
|
|
|
|
bool is_ref = false;
|
|
|
|
|
auto op_info = kernel::tbe::TbeDynamicShapeUtil::FindOp(AnfAlgo::GetCNodeName(kernel_node), kernel_node);
|
|
|
|
@ -376,9 +373,12 @@ void SetTensorDeviceInfo(const CNodePtr &kernel_node) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
|
|
|
|
|
std::vector<std::string> output_format = {AnfAlgo::GetOutputFormat(real_input_node, 0)};
|
|
|
|
|
if (IsValueNode<tensor::Tensor>(input_kernel_node) &&
|
|
|
|
|
AnfAlgo::GetOutputDeviceDataType(input_kernel_node, 0) == kTypeUnknown) {
|
|
|
|
|
std::vector<std::string> output_format = {selected_kernel_info->GetInputFormat(input_index)};
|
|
|
|
|
if (selected_kernel_info->GetInputFormat(input_index) != kOpFormat_FRACTAL_ZN_LSTM) {
|
|
|
|
|
output_format = {selected_kernel_info->GetInputFormat(input_index)};
|
|
|
|
|
}
|
|
|
|
|
builder->SetOutputsFormat(output_format);
|
|
|
|
|
std::vector<TypeId> output_type = {selected_kernel_info->GetInputDeviceType(input_index)};
|
|
|
|
|
builder->SetOutputsDeviceType(output_type);
|
|
|
|
@ -386,7 +386,9 @@ void SetTensorDeviceInfo(const CNodePtr &kernel_node) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown || is_ref) {
|
|
|
|
|
std::vector<std::string> output_format = {selected_kernel_info->GetInputFormat(input_index)};
|
|
|
|
|
if (selected_kernel_info->GetInputFormat(input_index) != kOpFormat_FRACTAL_ZN_LSTM) {
|
|
|
|
|
output_format = {selected_kernel_info->GetInputFormat(input_index)};
|
|
|
|
|
}
|
|
|
|
|
builder->SetOutputsFormat(output_format);
|
|
|
|
|
std::vector<TypeId> output_type = {selected_kernel_info->GetInputDeviceType(input_index)};
|
|
|
|
|
builder->SetOutputsDeviceType(output_type);
|
|
|
|
|