|
|
|
@ -42,8 +42,17 @@ bool StreamSwitchKernel::Init(const AnfNodePtr &anf_node) {
|
|
|
|
|
MS_LOG(INFO) << "stream switch op init start";
|
|
|
|
|
auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(primitive);
|
|
|
|
|
if (!AnfAlgo::HasNodeAttr(kAttrSwitchCondition, anf_node->cast<CNodePtr>())) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "StreamSwitchKernel has no attr kAttrSwitchCondition";
|
|
|
|
|
}
|
|
|
|
|
cond_ = tagRtCondition(GetValue<int>(primitive->GetAttr(kAttrSwitchCondition)));
|
|
|
|
|
if (!AnfAlgo::HasNodeAttr(kAttrTrueBranchStream, anf_node->cast<CNodePtr>())) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "StreamSwitchKernel has no attr kAttrTrueBranchStream";
|
|
|
|
|
}
|
|
|
|
|
true_stream_index_ = GetValue<uint32_t>(primitive->GetAttr(kAttrTrueBranchStream));
|
|
|
|
|
if (!AnfAlgo::HasNodeAttr(kAttrDataType, anf_node->cast<CNodePtr>())) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "StreamSwitchKernel has no attr kAttrDataType";
|
|
|
|
|
}
|
|
|
|
|
data_type_ = tagRtSwitchDataType(GetValue<int>(primitive->GetAttr(kAttrDataType)));
|
|
|
|
|
MS_LOG(INFO) << "cond_:" << static_cast<int>(cond_) << ", true_stream_index_:" << true_stream_index_
|
|
|
|
|
<< ", data_type_:" << static_cast<int>(data_type_);
|
|
|
|
@ -54,7 +63,7 @@ bool StreamSwitchKernel::Launch(const std::vector<AddressPtr> &inputs, const std
|
|
|
|
|
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
|
|
|
|
|
MS_LOG(INFO) << "stream switch op launch start";
|
|
|
|
|
if (inputs.size() != 2) {
|
|
|
|
|
MS_LOG(ERROR) << "Stream switch inputs size is " << inputs.size() << ", only support 2";
|
|
|
|
|
MS_LOG(EXCEPTION) << "Stream switch inputs size is " << inputs.size() << ", only support 2";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void *loop_cnt = inputs[0]->addr;
|
|
|
|
@ -73,7 +82,7 @@ std::vector<TaskInfoPtr> StreamSwitchKernel::GenTask(const std::vector<AddressPt
|
|
|
|
|
uint32_t stream_id) {
|
|
|
|
|
MS_LOG(INFO) << "StreamSwitchKernel GenTask start";
|
|
|
|
|
if (inputs.size() != 2) {
|
|
|
|
|
MS_LOG(ERROR) << "stream switch inputs size is " << inputs.size() << ", is not two";
|
|
|
|
|
MS_LOG(EXCEPTION) << "stream switch inputs size is " << inputs.size() << ", is not two";
|
|
|
|
|
}
|
|
|
|
|
stream_id_ = stream_id;
|
|
|
|
|
MS_EXCEPTION_IF_NULL(inputs[0]);
|
|
|
|
|