|
|
|
@ -50,7 +50,13 @@ bool SetIOIputSize(const std::shared_ptr<AnfNode> &anf_node, const size_t &input
|
|
|
|
|
MS_LOG(EXCEPTION) << "anf_node is not CNode.";
|
|
|
|
|
}
|
|
|
|
|
auto cnode = anf_node->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
|
if (cnode->inputs().size() < (i + 1)) {
|
|
|
|
|
MS_LOG(ERROR) << "cnode inputs size " << cnode->inputs().size() << " is smaller than " << i + 1;
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
auto input_node = cnode->inputs()[i + 1];
|
|
|
|
|
MS_EXCEPTION_IF_NULL(input_node);
|
|
|
|
|
if (input_node->isa<ValueNode>()) {
|
|
|
|
|
auto value_ptr = GetValueNode(input_node);
|
|
|
|
|
auto value = GetValue<std::string>(value_ptr);
|
|
|
|
@ -103,13 +109,13 @@ bool SetIOSize(const std::shared_ptr<AnfNode> &anf_node, const std::shared_ptr<A
|
|
|
|
|
output_size_list.push_back(IntToSize(size_i));
|
|
|
|
|
}
|
|
|
|
|
kernel_mod_ptr->SetOutputSizeList(output_size_list);
|
|
|
|
|
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ParseAttrValue(const std::string &type, const std::string &attr_name, const mindspore::ValuePtr &value,
|
|
|
|
|
::google::protobuf::Map<::std::string, ::mindspore::AttrValue> *node_attr) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node_attr);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(value);
|
|
|
|
|
if (type == "int") {
|
|
|
|
|
auto attr_value = GetValue<int>(value);
|
|
|
|
|
(*node_attr)[attr_name].set_i(attr_value);
|
|
|
|
@ -146,6 +152,8 @@ void ParseAttrValue(const std::string &type, const std::string &attr_name, const
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void SetNodeAttr(const std::shared_ptr<AnfNode> &anf_node, mindspore::NodeDef *proto) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(anf_node);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(proto);
|
|
|
|
|
std::string op_name = AnfAlgo::GetCNodeName(anf_node);
|
|
|
|
|
if (op_name == kInitDataSetQueue) {
|
|
|
|
|
op_name = kInitData;
|
|
|
|
@ -161,15 +169,16 @@ void SetNodeAttr(const std::shared_ptr<AnfNode> &anf_node, mindspore::NodeDef *p
|
|
|
|
|
MS_EXCEPTION_IF_NULL(primitive);
|
|
|
|
|
::google::protobuf::Map<::std::string, ::mindspore::AttrValue> *node_attr = proto->mutable_attrs();
|
|
|
|
|
for (const auto &attr_ptr : attrs_ptr) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(attr_ptr);
|
|
|
|
|
std::string attr_name = attr_ptr->name();
|
|
|
|
|
auto value = primitive->GetAttr(attr_name);
|
|
|
|
|
if (value != nullptr) {
|
|
|
|
|
if (attr_name == kQueueName || attr_name == kSharedName) {
|
|
|
|
|
attr_name = kChannelName;
|
|
|
|
|
} else if (attr_name == kSeed) {
|
|
|
|
|
attr_name = "seed";
|
|
|
|
|
} else if (attr_name == kSeed2) {
|
|
|
|
|
attr_name = "seed2";
|
|
|
|
|
} else if (attr_name == kSeed0) {
|
|
|
|
|
attr_name = kSeed;
|
|
|
|
|
} else if (attr_name == kSeed1) {
|
|
|
|
|
attr_name = kSeed2;
|
|
|
|
|
}
|
|
|
|
|
std::string type = attr_ptr->type();
|
|
|
|
|
ParseAttrValue(type, attr_name, value, node_attr);
|
|
|
|
@ -179,6 +188,8 @@ void SetNodeAttr(const std::shared_ptr<AnfNode> &anf_node, mindspore::NodeDef *p
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void SetNodeInputs(const std::shared_ptr<AnfNode> &anf_node, mindspore::NodeDef *proto) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(proto);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(anf_node);
|
|
|
|
|
size_t input_num = AnfAlgo::GetInputTensorNum(anf_node);
|
|
|
|
|
if (input_num == 0) {
|
|
|
|
|
MS_LOG(INFO) << "Node [" << AnfAlgo::GetCNodeName(anf_node) << "] does not have input.";
|
|
|
|
@ -193,6 +204,7 @@ void SetNodeInputs(const std::shared_ptr<AnfNode> &anf_node, mindspore::NodeDef
|
|
|
|
|
int32_t input_data_type;
|
|
|
|
|
if (input_type == kObjectTypeString) {
|
|
|
|
|
auto cnode = anf_node->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
|
auto input_node = cnode->inputs()[input_index + 1];
|
|
|
|
|
auto value_ptr = GetValueNode(input_node);
|
|
|
|
|
auto value = GetValue<std::string>(value_ptr);
|
|
|
|
@ -203,19 +215,20 @@ void SetNodeInputs(const std::shared_ptr<AnfNode> &anf_node, mindspore::NodeDef
|
|
|
|
|
input_shape = AnfAlgo::GetInputDeviceShape(anf_node, input_index);
|
|
|
|
|
input_data_type = AicpuOpUtil::MsTypeToProtoType(input_type);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
mindspore::TensorShape *tensorShape = node_inputs->mutable_tensor_shape();
|
|
|
|
|
for (auto item : input_shape) {
|
|
|
|
|
mindspore::TensorShape_Dim *dim = tensorShape->add_dim();
|
|
|
|
|
dim->set_size((::google::protobuf::int64)item);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
node_inputs->set_tensor_type((mindspore::DataType)input_data_type);
|
|
|
|
|
|
|
|
|
|
node_inputs->set_mem_device("HBM");
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void SetNodeOutputs(const std::shared_ptr<AnfNode> &anf_node, mindspore::NodeDef *proto) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(proto);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(anf_node);
|
|
|
|
|
size_t output_num = AnfAlgo::GetOutputTensorNum(anf_node);
|
|
|
|
|
if (output_num == 0) {
|
|
|
|
|
MS_LOG(INFO) << "Node [" << AnfAlgo::GetCNodeName(anf_node) << "] does not have output. ";
|
|
|
|
@ -224,63 +237,55 @@ void SetNodeOutputs(const std::shared_ptr<AnfNode> &anf_node, mindspore::NodeDef
|
|
|
|
|
|
|
|
|
|
for (size_t output_index = 0; output_index < output_num; output_index++) {
|
|
|
|
|
::mindspore::Tensor *node_outputs = proto->add_outputs();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node_outputs);
|
|
|
|
|
std::vector<size_t> output_shape = AnfAlgo::GetOutputDeviceShape(anf_node, output_index);
|
|
|
|
|
mindspore::TensorShape *tensorShape = node_outputs->mutable_tensor_shape();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(tensorShape);
|
|
|
|
|
for (auto item : output_shape) {
|
|
|
|
|
mindspore::TensorShape_Dim *dim = tensorShape->add_dim();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(dim);
|
|
|
|
|
dim->set_size((::google::protobuf::int64)item);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TypeId output_type = AnfAlgo::GetOutputDeviceDataType(anf_node, output_index);
|
|
|
|
|
|
|
|
|
|
int32_t output_data_type = AicpuOpUtil::MsTypeToProtoType(output_type);
|
|
|
|
|
node_outputs->set_tensor_type((mindspore::DataType)output_data_type);
|
|
|
|
|
|
|
|
|
|
node_outputs->set_mem_device("HBM");
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void SetNodedefProto(const std::shared_ptr<AnfNode> &anf_node, mindspore::NodeDef *proto) {
|
|
|
|
|
MS_LOG(INFO) << "SetNodedefProto entry";
|
|
|
|
|
MS_EXCEPTION_IF_NULL(anf_node);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(proto);
|
|
|
|
|
|
|
|
|
|
MS_LOG(INFO) << "SetNodedefProto entry";
|
|
|
|
|
std::string op_name = AnfAlgo::GetCNodeName(anf_node);
|
|
|
|
|
if (op_name == "InitDataSetQueue") {
|
|
|
|
|
op_name = "InitData";
|
|
|
|
|
if (op_name == kInitDataSetQueue) {
|
|
|
|
|
op_name = kInitData;
|
|
|
|
|
}
|
|
|
|
|
// set op name
|
|
|
|
|
proto->set_op(op_name);
|
|
|
|
|
|
|
|
|
|
// set inputs tensor
|
|
|
|
|
SetNodeInputs(anf_node, proto);
|
|
|
|
|
|
|
|
|
|
// set outputs tensor
|
|
|
|
|
SetNodeOutputs(anf_node, proto);
|
|
|
|
|
|
|
|
|
|
// set node attr
|
|
|
|
|
SetNodeAttr(anf_node, proto);
|
|
|
|
|
|
|
|
|
|
MS_LOG(INFO) << "SetNodedefProto end!";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool CreateNodeDefBytes(const std::shared_ptr<AnfNode> &anf_node,
|
|
|
|
|
const std::shared_ptr<AicpuOpKernelMod> &kernel_mod_ptr) {
|
|
|
|
|
MS_LOG(INFO) << "CreateNodeDefBytes entry";
|
|
|
|
|
MS_EXCEPTION_IF_NULL(anf_node);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_mod_ptr);
|
|
|
|
|
mindspore::NodeDef proto;
|
|
|
|
|
MS_EXCEPTION_IF_NULL(anf_node);
|
|
|
|
|
MS_LOG(INFO) << "CreateNodeDefBytes entry";
|
|
|
|
|
|
|
|
|
|
mindspore::NodeDef proto;
|
|
|
|
|
SetNodedefProto(anf_node, &proto);
|
|
|
|
|
|
|
|
|
|
std::string nodeDefStr;
|
|
|
|
|
if (!proto.SerializeToString(&nodeDefStr)) {
|
|
|
|
|
MS_LOG(ERROR) << "Serialize nodeDef to string failed.";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
kernel_mod_ptr->SetNodeDef(nodeDefStr);
|
|
|
|
|
|
|
|
|
|
MS_LOG(INFO) << "CreateNodeDefBytes end!";
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
@ -288,8 +293,8 @@ bool CreateNodeDefBytes(const std::shared_ptr<AnfNode> &anf_node,
|
|
|
|
|
KernelModPtr AicpuOpBuild(const std::shared_ptr<AnfNode> &anf_node) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(anf_node);
|
|
|
|
|
std::string op_name = AnfAlgo::GetCNodeName(anf_node);
|
|
|
|
|
if (op_name == "InitDataSetQueue") {
|
|
|
|
|
op_name = "InitData";
|
|
|
|
|
if (op_name == kInitDataSetQueue) {
|
|
|
|
|
op_name = kInitData;
|
|
|
|
|
}
|
|
|
|
|
auto kernel_mod_ptr = std::make_shared<AicpuOpKernelMod>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_mod_ptr);
|
|
|
|
|