From b69c2003316e89a275d708a3ef99ac073dca1c38 Mon Sep 17 00:00:00 2001 From: zjun Date: Mon, 15 Jun 2020 10:05:40 +0800 Subject: [PATCH] fix code review --- .../ccsrc/kernel/aicpu/aicpu_kernel_build.cc | 57 ++++++++++--------- .../ccsrc/kernel/aicpu/aicpu_kernel_mod.cc | 8 +-- mindspore/ccsrc/kernel/aicpu/aicpu_util.h | 10 ++-- mindspore/ccsrc/kernel/oplib/opinfo.h | 8 --- mindspore/ccsrc/kernel/oplib/oplib.cc | 3 + 5 files changed, 44 insertions(+), 42 deletions(-) diff --git a/mindspore/ccsrc/kernel/aicpu/aicpu_kernel_build.cc b/mindspore/ccsrc/kernel/aicpu/aicpu_kernel_build.cc index d6217ff1cc..1afe01bd6a 100644 --- a/mindspore/ccsrc/kernel/aicpu/aicpu_kernel_build.cc +++ b/mindspore/ccsrc/kernel/aicpu/aicpu_kernel_build.cc @@ -50,7 +50,13 @@ bool SetIOIputSize(const std::shared_ptr &anf_node, const size_t &input MS_LOG(EXCEPTION) << "anf_node is not CNode."; } auto cnode = anf_node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (cnode->inputs().size() < (i + 1)) { + MS_LOG(ERROR) << "cnode inputs size " << cnode->inputs().size() << " is smaller than " << i + 1; + return false; + } auto input_node = cnode->inputs()[i + 1]; + MS_EXCEPTION_IF_NULL(input_node); if (input_node->isa()) { auto value_ptr = GetValueNode(input_node); auto value = GetValue(value_ptr); @@ -103,13 +109,13 @@ bool SetIOSize(const std::shared_ptr &anf_node, const std::shared_ptrSetOutputSizeList(output_size_list); - return true; } void ParseAttrValue(const std::string &type, const std::string &attr_name, const mindspore::ValuePtr &value, ::google::protobuf::Map<::std::string, ::mindspore::AttrValue> *node_attr) { MS_EXCEPTION_IF_NULL(node_attr); + MS_EXCEPTION_IF_NULL(value); if (type == "int") { auto attr_value = GetValue(value); (*node_attr)[attr_name].set_i(attr_value); @@ -146,6 +152,8 @@ void ParseAttrValue(const std::string &type, const std::string &attr_name, const } void SetNodeAttr(const std::shared_ptr &anf_node, mindspore::NodeDef *proto) { + MS_EXCEPTION_IF_NULL(anf_node); + MS_EXCEPTION_IF_NULL(proto); std::string op_name = AnfAlgo::GetCNodeName(anf_node); if (op_name == kInitDataSetQueue) { op_name = kInitData; @@ -161,15 +169,16 @@ void SetNodeAttr(const std::shared_ptr &anf_node, mindspore::NodeDef *p MS_EXCEPTION_IF_NULL(primitive); ::google::protobuf::Map<::std::string, ::mindspore::AttrValue> *node_attr = proto->mutable_attrs(); for (const auto &attr_ptr : attrs_ptr) { + MS_EXCEPTION_IF_NULL(attr_ptr); std::string attr_name = attr_ptr->name(); auto value = primitive->GetAttr(attr_name); if (value != nullptr) { if (attr_name == kQueueName || attr_name == kSharedName) { attr_name = kChannelName; - } else if (attr_name == kSeed) { - attr_name = "seed"; - } else if (attr_name == kSeed2) { - attr_name = "seed2"; + } else if (attr_name == kSeed0) { + attr_name = kSeed; + } else if (attr_name == kSeed1) { + attr_name = kSeed2; } std::string type = attr_ptr->type(); ParseAttrValue(type, attr_name, value, node_attr); @@ -179,6 +188,8 @@ void SetNodeAttr(const std::shared_ptr &anf_node, mindspore::NodeDef *p } void SetNodeInputs(const std::shared_ptr &anf_node, mindspore::NodeDef *proto) { + MS_EXCEPTION_IF_NULL(proto); + MS_EXCEPTION_IF_NULL(anf_node); size_t input_num = AnfAlgo::GetInputTensorNum(anf_node); if (input_num == 0) { MS_LOG(INFO) << "Node [" << AnfAlgo::GetCNodeName(anf_node) << "] does not have input."; @@ -193,6 +204,7 @@ void SetNodeInputs(const std::shared_ptr &anf_node, mindspore::NodeDef int32_t input_data_type; if (input_type == kObjectTypeString) { auto cnode = anf_node->cast(); + MS_EXCEPTION_IF_NULL(cnode); auto input_node = cnode->inputs()[input_index + 1]; auto value_ptr = GetValueNode(input_node); auto value = GetValue(value_ptr); @@ -203,19 +215,20 @@ void SetNodeInputs(const std::shared_ptr &anf_node, mindspore::NodeDef input_shape = AnfAlgo::GetInputDeviceShape(anf_node, input_index); input_data_type = AicpuOpUtil::MsTypeToProtoType(input_type); } + mindspore::TensorShape *tensorShape = node_inputs->mutable_tensor_shape(); for (auto item : input_shape) { mindspore::TensorShape_Dim *dim = tensorShape->add_dim(); dim->set_size((::google::protobuf::int64)item); } - node_inputs->set_tensor_type((mindspore::DataType)input_data_type); - node_inputs->set_mem_device("HBM"); } } void SetNodeOutputs(const std::shared_ptr &anf_node, mindspore::NodeDef *proto) { + MS_EXCEPTION_IF_NULL(proto); + MS_EXCEPTION_IF_NULL(anf_node); size_t output_num = AnfAlgo::GetOutputTensorNum(anf_node); if (output_num == 0) { MS_LOG(INFO) << "Node [" << AnfAlgo::GetCNodeName(anf_node) << "] does not have output. "; @@ -224,63 +237,55 @@ void SetNodeOutputs(const std::shared_ptr &anf_node, mindspore::NodeDef for (size_t output_index = 0; output_index < output_num; output_index++) { ::mindspore::Tensor *node_outputs = proto->add_outputs(); + MS_EXCEPTION_IF_NULL(node_outputs); std::vector output_shape = AnfAlgo::GetOutputDeviceShape(anf_node, output_index); mindspore::TensorShape *tensorShape = node_outputs->mutable_tensor_shape(); + MS_EXCEPTION_IF_NULL(tensorShape); for (auto item : output_shape) { mindspore::TensorShape_Dim *dim = tensorShape->add_dim(); + MS_EXCEPTION_IF_NULL(dim); dim->set_size((::google::protobuf::int64)item); } - TypeId output_type = AnfAlgo::GetOutputDeviceDataType(anf_node, output_index); - int32_t output_data_type = AicpuOpUtil::MsTypeToProtoType(output_type); node_outputs->set_tensor_type((mindspore::DataType)output_data_type); - node_outputs->set_mem_device("HBM"); } } void SetNodedefProto(const std::shared_ptr &anf_node, mindspore::NodeDef *proto) { - MS_LOG(INFO) << "SetNodedefProto entry"; MS_EXCEPTION_IF_NULL(anf_node); MS_EXCEPTION_IF_NULL(proto); - + MS_LOG(INFO) << "SetNodedefProto entry"; std::string op_name = AnfAlgo::GetCNodeName(anf_node); - if (op_name == "InitDataSetQueue") { - op_name = "InitData"; + if (op_name == kInitDataSetQueue) { + op_name = kInitData; } // set op name proto->set_op(op_name); - // set inputs tensor SetNodeInputs(anf_node, proto); - // set outputs tensor SetNodeOutputs(anf_node, proto); - // set node attr SetNodeAttr(anf_node, proto); - MS_LOG(INFO) << "SetNodedefProto end!"; } bool CreateNodeDefBytes(const std::shared_ptr &anf_node, const std::shared_ptr &kernel_mod_ptr) { - MS_LOG(INFO) << "CreateNodeDefBytes entry"; - MS_EXCEPTION_IF_NULL(anf_node); MS_EXCEPTION_IF_NULL(kernel_mod_ptr); - mindspore::NodeDef proto; + MS_EXCEPTION_IF_NULL(anf_node); + MS_LOG(INFO) << "CreateNodeDefBytes entry"; + mindspore::NodeDef proto; SetNodedefProto(anf_node, &proto); - std::string nodeDefStr; if (!proto.SerializeToString(&nodeDefStr)) { MS_LOG(ERROR) << "Serialize nodeDef to string failed."; return false; } - kernel_mod_ptr->SetNodeDef(nodeDefStr); - MS_LOG(INFO) << "CreateNodeDefBytes end!"; return true; } @@ -288,8 +293,8 @@ bool CreateNodeDefBytes(const std::shared_ptr &anf_node, KernelModPtr AicpuOpBuild(const std::shared_ptr &anf_node) { MS_EXCEPTION_IF_NULL(anf_node); std::string op_name = AnfAlgo::GetCNodeName(anf_node); - if (op_name == "InitDataSetQueue") { - op_name = "InitData"; + if (op_name == kInitDataSetQueue) { + op_name = kInitData; } auto kernel_mod_ptr = std::make_shared(); MS_EXCEPTION_IF_NULL(kernel_mod_ptr); diff --git a/mindspore/ccsrc/kernel/aicpu/aicpu_kernel_mod.cc b/mindspore/ccsrc/kernel/aicpu/aicpu_kernel_mod.cc index 7875baaf0e..2213f176cc 100644 --- a/mindspore/ccsrc/kernel/aicpu/aicpu_kernel_mod.cc +++ b/mindspore/ccsrc/kernel/aicpu/aicpu_kernel_mod.cc @@ -110,8 +110,8 @@ bool AicpuOpKernelMod::Launch(const std::vector &inputs, const std:: } CreateCpuKernelInfo(inputs, outputs); - if (node_name_ == "TopK") { - node_name_ = "TopKV2"; + if (node_name_ == kTopK) { + node_name_ = kTopKV2; } MS_LOG(INFO) << "Aicpu launch, node_so_:" << node_so_ << ", node name:" << node_name_ << ", args_size:" << args_.length(); @@ -141,8 +141,8 @@ std::vector AicpuOpKernelMod::GenTask(const std::vector (void)std::transform(std::begin(outputs), std::end(outputs), std::back_inserter(output_data_addrs), [](const AddressPtr &output) -> void * { return output->addr; }); - if (node_name_ == "TopK") { - node_name_ = "TopKV2"; + if (node_name_ == kTopK) { + node_name_ = kTopKV2; } AicpuTaskInfoPtr task_info_ptr = make_shared( stream_id, node_so_, node_name_, node_def_str_, input_data_addrs, output_data_addrs); diff --git a/mindspore/ccsrc/kernel/aicpu/aicpu_util.h b/mindspore/ccsrc/kernel/aicpu/aicpu_util.h index b6f43414e3..22d41d3dd9 100644 --- a/mindspore/ccsrc/kernel/aicpu/aicpu_util.h +++ b/mindspore/ccsrc/kernel/aicpu/aicpu_util.h @@ -28,7 +28,6 @@ constexpr auto kInitDataSetQueue = "InitDataSetQueue"; constexpr auto kInitData = "InitData"; constexpr auto kGetNext = "GetNext"; constexpr auto kPrint = "Print"; - constexpr auto kOutputTypes = "output_types"; constexpr auto kOutputShapes = "output_shapes"; constexpr auto kChannelName = "channel_name"; @@ -36,9 +35,12 @@ constexpr auto kSharedName = "shared_name"; constexpr auto kShapes = "shapes"; constexpr auto kTypes = "types"; constexpr auto kQueueName = "queue_name"; - -constexpr auto kSeed = "Seed0"; -constexpr auto kSeed2 = "Seed1"; +constexpr auto kSeed = "seed"; +constexpr auto kSeed0 = "Seed0"; +constexpr auto kSeed1 = "Seed1"; +constexpr auto kSeed2 = "seed2"; +constexpr auto kTopK = "TopK"; +constexpr auto kTopKV2 = "TopKV2"; struct AicpuParamHead { uint32_t length; // Total length: include cunstom message diff --git a/mindspore/ccsrc/kernel/oplib/opinfo.h b/mindspore/ccsrc/kernel/oplib/opinfo.h index 4d13308598..bb8defe74d 100644 --- a/mindspore/ccsrc/kernel/oplib/opinfo.h +++ b/mindspore/ccsrc/kernel/oplib/opinfo.h @@ -95,12 +95,7 @@ class OpInfo { OpImplyType imply_type() const { return imply_type_; } std::string impl_path() const { return impl_path_; } std::string fusion_type() const { return fusion_type_; } - bool async_flag() const { return async_flag_; } - std::string binfile_name() const { return binfile_name_; } - int compute_cost() const { return compute_cost_; } std::string kernel_name() const { return kernel_name_; } - bool partial_flag() const { return partial_flag_; } - bool dynamic_format() const { return dynamic_format_; } OpPattern op_pattern() const { return op_pattern_; } std::vector> attrs_ptr() const { return attrs_ptr_; } std::vector> inputs_ptr() const { return inputs_ptr_; } @@ -116,13 +111,10 @@ class OpInfo { void set_compute_cost(const int compute_cost) { compute_cost_ = compute_cost; } void set_kernel_name(const std::string &kernel_name) { kernel_name_ = kernel_name; } void set_partial_flag(const bool partial_flag) { partial_flag_ = partial_flag; } - void set_dynamic_format(const bool dynamic_format) { dynamic_format_ = dynamic_format; } void set_op_pattern(const OpPattern op_pattern) { op_pattern_ = op_pattern; } void add_attrs_ptr(const std::shared_ptr &attr) { attrs_ptr_.push_back(attr); } void add_inputs_ptr(const std::shared_ptr &input) { inputs_ptr_.push_back(input); } void add_outputs_ptr(const std::shared_ptr &output) { outputs_ptr_.push_back(output); } - void set_inputs_ptr(const std::vector> &inputs) { inputs_ptr_ = inputs; } - void set_outputs_ptr(const std::vector> &outputs) { outputs_ptr_ = outputs; } bool is_ref() const { return !ref_infos_.empty(); } bool has_ref_index(size_t out_index) const { return ref_infos_.find(out_index) != ref_infos_.end(); } void add_ref_pair(size_t out_index, size_t in_index) { (void)ref_infos_.emplace(out_index, in_index); } diff --git a/mindspore/ccsrc/kernel/oplib/oplib.cc b/mindspore/ccsrc/kernel/oplib/oplib.cc index 1b367120b4..42ec534ae0 100644 --- a/mindspore/ccsrc/kernel/oplib/oplib.cc +++ b/mindspore/ccsrc/kernel/oplib/oplib.cc @@ -103,6 +103,7 @@ void OpLib::DecodeTBESpecificInfo(const nlohmann::json &obj, const std::shared_p {kBroadcast, kBroadcastPattern}, {kReduce, kReducePattern}, {kDynamicFormat, kDynamicFormatPattern}}; + MS_EXCEPTION_IF_NULL(op_info); op_info->set_async_flag(obj.at(kAsyncFlag)); op_info->set_binfile_name(obj.at(kBinfileName)); op_info->set_compute_cost(obj.at(kComputeCost)); @@ -199,6 +200,7 @@ bool OpLib::DecodeAttr(const nlohmann::json &obj, const OpImplyType imply_type, bool OpLib::DecodeDtypeFormat(const nlohmann::json &dtype_format, const std::shared_ptr &op_io, size_t index) { + MS_EXCEPTION_IF_NULL(op_io); bool ret = true; try { std::vector dtype; @@ -218,6 +220,7 @@ bool OpLib::DecodeDtypeFormat(const nlohmann::json &dtype_format, const std::sha bool OpLib::DecodeInputOutput(const nlohmann::json &obj, const OpImplyType imply_type, const OpIOType io_type, const std::shared_ptr &op_info, const nlohmann::json &dtype_format) { + MS_EXCEPTION_IF_NULL(op_info); bool ret = true; try { std::shared_ptr op_io = std::make_shared();