|
|
|
@ -204,7 +204,7 @@ Status DropoutDoMaskInfo::InitForCostModel(const StrategyPtr &strategy) {
|
|
|
|
|
|
|
|
|
|
PrimitivePtr GetDropoutGenMaskPrim(const CNodePtr &cnode) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
|
if (cnode->size() != DROPOUT_DO_MASK_CNODE_INPUT_SIZE) {
|
|
|
|
|
if (cnode->inputs().size() != DROPOUT_DO_MASK_CNODE_INPUT_SIZE) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "The size of dropout do mask cnode's inputs must be " << DROPOUT_DO_MASK_CNODE_INPUT_SIZE;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -215,7 +215,8 @@ PrimitivePtr GetDropoutGenMaskPrim(const CNodePtr &cnode) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto dropout_gen_mask_cnode = dropout_gen_mask->cast<CNodePtr>();
|
|
|
|
|
if (dropout_gen_mask_cnode->size() != DROPOUT_GEN_MASK_CNODE_INPUT_SIZE) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(dropout_gen_mask_cnode);
|
|
|
|
|
if (dropout_gen_mask_cnode->inputs().size() != DROPOUT_GEN_MASK_CNODE_INPUT_SIZE) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "The size of dropout gen mask cnode's inputs must be " << DROPOUT_GEN_MASK_CNODE_INPUT_SIZE;
|
|
|
|
|
}
|
|
|
|
|
if (!IsValueNode<Primitive>(dropout_gen_mask_cnode->input(0))) {
|
|
|
|
@ -232,45 +233,11 @@ PrimitivePtr GetDropoutGenMaskPrim(const CNodePtr &cnode) {
|
|
|
|
|
return prim;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void SetGenMaskShape(const CNodePtr &cnode, const Shape &input_slice_shape) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
|
if (cnode->size() != DROPOUT_DO_MASK_CNODE_INPUT_SIZE) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "The size of dropout do mask cnode's inputs must be " << DROPOUT_DO_MASK_CNODE_INPUT_SIZE;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
AnfNodePtr dropout_gen_mask = cnode->input(DROPOUT_GEN_MASK_INDEX);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(dropout_gen_mask);
|
|
|
|
|
if (!dropout_gen_mask->isa<CNode>()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "The dropout do mask cnode's input[" << DROPOUT_GEN_MASK_INDEX << "] must be a cnode.";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto dropout_gen_mask_cnode = dropout_gen_mask->cast<CNodePtr>();
|
|
|
|
|
if (dropout_gen_mask_cnode->size() != DROPOUT_GEN_MASK_CNODE_INPUT_SIZE) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "The size of dropout gen mask cnode's inputs must be " << DROPOUT_GEN_MASK_CNODE_INPUT_SIZE;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (!IsValueNode<ValueTuple>(dropout_gen_mask_cnode->input(1))) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "The input[1] of dropout gen mask cnode is not ValueTuple.";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
FuncGraphPtr func_graph = cnode->func_graph();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(func_graph);
|
|
|
|
|
FuncGraphManagerPtr manager = func_graph->manager();
|
|
|
|
|
if (manager == nullptr) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Failure: AddNode error since manager is nullptr.";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ValuePtr new_shape = MakeValue(input_slice_shape);
|
|
|
|
|
AnfNodePtr val = NewValueNode(new_shape);
|
|
|
|
|
(void)manager->Replace(dropout_gen_mask_cnode->input(1), val);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// DropoutDoMask needs to be used together with DropoutGenMask. Only the first input tensor of DropoutGenMask is
|
|
|
|
|
// split. Find the DropoutGenMask node in the anf graph according to DropoutDoMask node, and modify the input shape
|
|
|
|
|
// of DropoutGenMask according to the strategy of DropoutDoMask. When the DropoutDoMask performs repeated calculation
|
|
|
|
|
// and both seeds of DropoutGenMask are 0, two new seeds are automatically generated for DropoutGenMask.
|
|
|
|
|
std::vector<Operator> DropoutDoMaskInfo::GetDropoutGenMaskReplaceOp(const CNodePtr &cnode) {
|
|
|
|
|
std::vector<Operator> replace_ops;
|
|
|
|
|
Operator DropoutDoMaskInfo::GetDropoutGenMaskReplaceOp(const CNodePtr &cnode) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
|
PrimitivePtr prim = GetDropoutGenMaskPrim(cnode);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(prim);
|
|
|
|
@ -293,20 +260,15 @@ std::vector<Operator> DropoutDoMaskInfo::GetDropoutGenMaskReplaceOp(const CNodeP
|
|
|
|
|
if ((attr.find(SEED0) == attr.end()) || (attr.find(SEED1) == attr.end())) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "The attrs of dropout gen mask must be have seed0 and seed1";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Shape input_slice_shape = inputs_tensor_info_[0].slice_shape();
|
|
|
|
|
int32_t seed_0 = GetValue<int32_t>(attr[SEED0]);
|
|
|
|
|
int32_t seed_1 = GetValue<int32_t>(attr[SEED1]);
|
|
|
|
|
if ((seed_0 == 0) && (seed_1 == 0) && (repeated_calc_num_ > 1)) {
|
|
|
|
|
seed_0 = SEED_NUM;
|
|
|
|
|
seed_1 = SEED_NUM;
|
|
|
|
|
SEED_NUM++;
|
|
|
|
|
} else {
|
|
|
|
|
SetGenMaskShape(cnode, input_slice_shape);
|
|
|
|
|
MS_LOG(DEBUG) << "The input slice shape droupout is " << ShapeToString(input_slice_shape);
|
|
|
|
|
return replace_ops;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Shape input_slice_shape = inputs_tensor_info_[0].slice_shape();
|
|
|
|
|
ValuePtr new_shape = MakeValue(input_slice_shape);
|
|
|
|
|
Attr attr_0 = std::make_pair(SEED0, MakeValue(seed_0));
|
|
|
|
|
Attr attr_1 = std::make_pair(SEED1, MakeValue(seed_1));
|
|
|
|
@ -316,8 +278,7 @@ std::vector<Operator> DropoutDoMaskInfo::GetDropoutGenMaskReplaceOp(const CNodeP
|
|
|
|
|
OperatorParams params = {std::make_pair(param_0, 1), std::make_pair(param_1, 2)};
|
|
|
|
|
OperatorArgs args = std::make_pair(attrs, params);
|
|
|
|
|
Operator replace_op = {std::make_pair(DROPOUT_GEN_MASK, args)};
|
|
|
|
|
replace_ops.push_back(replace_op);
|
|
|
|
|
return replace_ops;
|
|
|
|
|
return replace_op;
|
|
|
|
|
}
|
|
|
|
|
} // namespace parallel
|
|
|
|
|
} // namespace mindspore
|
|
|
|
|