|
|
@ -235,7 +235,7 @@ size_t GetDtypeNbyte(const std::string &dtypes) {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
bool SetInputKernelBuilderInfo(const std::vector<std::shared_ptr<OpIOInfo>> &inputs, size_t real_input_num,
|
|
|
|
bool SetInputKernelBuilderInfo(const std::vector<std::shared_ptr<OpIOInfo>> &inputs, size_t real_input_num,
|
|
|
|
size_t builder_idex, const std::vector<int> &dyn_input_sizes,
|
|
|
|
size_t builder_idex, const std::vector<int64_t> &dyn_input_sizes,
|
|
|
|
const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> &builder) {
|
|
|
|
const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> &builder) {
|
|
|
|
MS_EXCEPTION_IF_NULL(builder);
|
|
|
|
MS_EXCEPTION_IF_NULL(builder);
|
|
|
|
|
|
|
|
|
|
|
@ -262,7 +262,7 @@ bool SetInputKernelBuilderInfo(const std::vector<std::shared_ptr<OpIOInfo>> &inp
|
|
|
|
return false;
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
for (int t = 0; t < dyn_input_sizes[dyn_input_idx]; t++) {
|
|
|
|
for (int64_t t = 0; t < dyn_input_sizes[dyn_input_idx]; t++) {
|
|
|
|
kernel_info_index++;
|
|
|
|
kernel_info_index++;
|
|
|
|
auto type_id = DtypeToTypeId(dtypes[builder_idex]);
|
|
|
|
auto type_id = DtypeToTypeId(dtypes[builder_idex]);
|
|
|
|
inputs_device_type.push_back(type_id);
|
|
|
|
inputs_device_type.push_back(type_id);
|
|
|
@ -376,11 +376,11 @@ bool ParseMetadata(const CNodePtr &kernel_node, const std::shared_ptr<const OpIn
|
|
|
|
size_t real_output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
|
|
|
|
size_t real_output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
|
|
|
|
std::vector<std::shared_ptr<OpIOInfo>> inputs = op_info_ptr->inputs_ptr();
|
|
|
|
std::vector<std::shared_ptr<OpIOInfo>> inputs = op_info_ptr->inputs_ptr();
|
|
|
|
std::vector<std::shared_ptr<OpIOInfo>> outputs = op_info_ptr->outputs_ptr();
|
|
|
|
std::vector<std::shared_ptr<OpIOInfo>> outputs = op_info_ptr->outputs_ptr();
|
|
|
|
std::vector<int> dyn_input_sizes;
|
|
|
|
std::vector<int64_t> dyn_input_sizes;
|
|
|
|
auto primitive = AnfAlgo::GetCNodePrimitive(kernel_node);
|
|
|
|
auto primitive = AnfAlgo::GetCNodePrimitive(kernel_node);
|
|
|
|
MS_EXCEPTION_IF_NULL(primitive);
|
|
|
|
MS_EXCEPTION_IF_NULL(primitive);
|
|
|
|
if (primitive->GetAttr("dyn_input_sizes") != nullptr) {
|
|
|
|
if (primitive->GetAttr("dyn_input_sizes") != nullptr) {
|
|
|
|
dyn_input_sizes = GetValue<std::vector<int>>(primitive->GetAttr("dyn_input_sizes"));
|
|
|
|
dyn_input_sizes = GetValue<std::vector<int64_t>>(primitive->GetAttr("dyn_input_sizes"));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if (inputs.size() > 0) {
|
|
|
|
if (inputs.size() > 0) {
|
|
|
|
MS_EXCEPTION_IF_NULL(inputs[0]);
|
|
|
|
MS_EXCEPTION_IF_NULL(inputs[0]);
|
|
|
@ -552,11 +552,11 @@ std::vector<std::pair<AnfNodePtr, std::pair<size_t, size_t>>> GetInputIndex(cons
|
|
|
|
continue;
|
|
|
|
continue;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
std::vector<int> dyn_input_sizes;
|
|
|
|
std::vector<int64_t> dyn_input_sizes;
|
|
|
|
auto prim = AnfAlgo::GetCNodePrimitive(anf_node);
|
|
|
|
auto prim = AnfAlgo::GetCNodePrimitive(anf_node);
|
|
|
|
MS_EXCEPTION_IF_NULL(prim);
|
|
|
|
MS_EXCEPTION_IF_NULL(prim);
|
|
|
|
if (prim->GetAttr(kAttrDynInputSizes) != nullptr) {
|
|
|
|
if (prim->GetAttr(kAttrDynInputSizes) != nullptr) {
|
|
|
|
dyn_input_sizes = GetValue<const std::vector<int>>(prim->GetAttr(kAttrDynInputSizes));
|
|
|
|
dyn_input_sizes = GetValue<const std::vector<int64_t>>(prim->GetAttr(kAttrDynInputSizes));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
if (dyn_input_sizes.empty()) {
|
|
|
|
if (dyn_input_sizes.empty()) {
|
|
|
@ -764,28 +764,26 @@ bool IsWeightBoundary(const AnfNodePtr &node) {
|
|
|
|
return false;
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
std::vector<int> GetReduceAttrAxis(const CNodePtr &cnode) {
|
|
|
|
std::vector<int64_t> GetReduceAttrAxis(const CNodePtr &cnode) {
|
|
|
|
if (AnfAlgo::GetInputTensorNum(cnode) != AnfAlgo::GetOutputTensorNum(cnode) &&
|
|
|
|
if (AnfAlgo::GetInputTensorNum(cnode) != AnfAlgo::GetOutputTensorNum(cnode) &&
|
|
|
|
AnfAlgo::GetInputTensorNum(cnode) != 1) {
|
|
|
|
AnfAlgo::GetInputTensorNum(cnode) != 1) {
|
|
|
|
MS_LOG(EXCEPTION) << "the kind of reduce node [" << cnode->DebugString()
|
|
|
|
MS_LOG(EXCEPTION) << "the kind of reduce node [" << cnode->DebugString()
|
|
|
|
<< "] is not single input or single output ";
|
|
|
|
<< "] is not single input or single output ";
|
|
|
|
}
|
|
|
|
}
|
|
|
|
std::vector<int> axis;
|
|
|
|
std::vector<int64_t> axis;
|
|
|
|
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, 0);
|
|
|
|
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, 0);
|
|
|
|
auto primitive = AnfAlgo::GetCNodePrimitive(cnode);
|
|
|
|
auto primitive = AnfAlgo::GetCNodePrimitive(cnode);
|
|
|
|
MS_EXCEPTION_IF_NULL(primitive);
|
|
|
|
MS_EXCEPTION_IF_NULL(primitive);
|
|
|
|
auto axis_attr = primitive->GetAttr(kAxis);
|
|
|
|
auto axis_attr = primitive->GetAttr(kAxis);
|
|
|
|
if (axis_attr == nullptr) {
|
|
|
|
if (axis_attr == nullptr) {
|
|
|
|
MS_LOG(ERROR) << "This node does't have axie attr.";
|
|
|
|
MS_LOG(ERROR) << "This node does't have axie attr.";
|
|
|
|
return std::vector<int>();
|
|
|
|
return std::vector<int64_t>();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
auto type = axis_attr->type();
|
|
|
|
std::vector<int64_t> axis_list;
|
|
|
|
MS_EXCEPTION_IF_NULL(type);
|
|
|
|
if (axis_attr->isa<Int64Imm>()) {
|
|
|
|
std::vector<int> axis_list;
|
|
|
|
axis_list.emplace_back(GetValue<int64_t>(axis_attr));
|
|
|
|
if (type->ToString() == kTypeInt32) {
|
|
|
|
|
|
|
|
axis_list.emplace_back(GetValue<int>(axis_attr));
|
|
|
|
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
axis_list = GetValue<std::vector<int>>(axis_attr);
|
|
|
|
axis_list = GetValue<std::vector<int64_t>>(axis_attr);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
for (const auto &elem : axis_list) {
|
|
|
|
for (const auto &elem : axis_list) {
|
|
|
|
if (elem < 0) {
|
|
|
|
if (elem < 0) {
|
|
|
|