fix dropout unify_mindir pass

pull/10238/head
jjfeing 4 years ago committed by yuchaojie
parent b41d83a7df
commit 389da54525

@@ -42,6 +42,24 @@ class DropoutGradUnifyMindIR : public PatternProcessPass {
  private:
   VarPtr grad_input_;
 };
+
+class DropoutUnifyMindIRPynative : public PatternProcessPass {
+ public:
+  explicit DropoutUnifyMindIRPynative(bool multigraph = true)
+      : PatternProcessPass("dropout_unify_mindir_pynative", multigraph) {}
+  ~DropoutUnifyMindIRPynative() override = default;
+  const BaseRef DefinePattern() const override;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+};
+
+class DropoutGradUnifyMindIRPynative : public PatternProcessPass {
+ public:
+  explicit DropoutGradUnifyMindIRPynative(bool multigraph = true)
+      : PatternProcessPass("dropout_grad_unify_mindir_pynative", multigraph) {}
+  ~DropoutGradUnifyMindIRPynative() override = default;
+  const BaseRef DefinePattern() const override;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+};
 }  // namespace opt
 }  // namespace mindspore
 #endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_DROPOUT_UNIFY_MINDIR_H_

@@ -444,6 +444,15 @@ void AscendSession::UnifyMindIR(const KernelGraphPtr &graph) {
   unify_mindir_pm->AddPass(std::make_shared<opt::Conv2DUnifyMindIR>());
   unify_mindir_pm->AddPass(std::make_shared<opt::Conv2DBackpropInputUnifyMindIR>());
   unify_mindir_pm->AddPass(std::make_shared<opt::Conv2DBackpropFilterUnifyMindIR>());
+  auto ms_context = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(ms_context);
+  if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
+    unify_mindir_pm->AddPass(std::make_shared<opt::DropoutGradUnifyMindIR>());
+    unify_mindir_pm->AddPass(std::make_shared<opt::DropoutUnifyMindIR>());
+  } else {
+    unify_mindir_pm->AddPass(std::make_shared<opt::DropoutUnifyMindIRPynative>());
+    unify_mindir_pm->AddPass(std::make_shared<opt::DropoutGradUnifyMindIRPynative>());
+  }
 
   optimizer->AddPassManager(unify_mindir_pm);
   (void)optimizer->Optimize(graph);
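
Note: the new branch keys off MS_CTX_EXECUTION_MODE, which mirrors the execution mode a user selects from Python. A minimal sketch of how that mode is chosen (standard mindspore.context API, shown only for orientation, not part of this patch):

import mindspore.context as context

# GRAPH_MODE keeps the original DropoutUnifyMindIR / DropoutGradUnifyMindIR
# passes; PYNATIVE_MODE would take the new *Pynative passes added above.
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")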

@@ -1633,7 +1633,11 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructSingleOpGraph(const OpRunInf
     manager->AddFuncGraph(graph);
     graph->set_manager(manager);
   }
-  UnifyMindIR(graph);
+  auto ms_context = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(ms_context);
+  if (ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER)) {
+    UnifyMindIR(graph);
+  }
   return graph;
 }
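
For context, ConstructSingleOpGraph is reached on the PyNative single-op path, i.e. whenever an operator is executed eagerly; with this change UnifyMindIR is skipped there unless pynative infer is enabled. A minimal sketch of an eager call that exercises this path (ordinary PyNative op execution, assumes an Ascend environment, not part of this patch):

import numpy as np
import mindspore.context as context
from mindspore import Tensor
import mindspore.ops.operations as P

context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")

# Each eager call like this is lowered to a single-op kernel graph by the
# backend session, which is where ConstructSingleOpGraph is invoked.
dropout = P.Dropout(keep_prob=0.9)
out, mask = dropout(Tensor(np.ones((2, 3)).astype(np.float32)))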

@@ -29,7 +29,6 @@ from mindspore.ops.primitive import constexpr, Primitive
 from mindspore.common.parameter import Parameter
 from mindspore._extends import cell_attr_register
 from mindspore._checkparam import Rel, Validator
-from mindspore import context
 from ..cell import Cell
 from .activation import get_activation
@@ -146,33 +145,17 @@ class Dropout(Cell):
         seed0, seed1 = _get_graph_seed(0, "dropout")
         self.seed0 = seed0
         self.seed1 = seed1
-        self.dtype = dtype
-        self.get_shape = P.Shape()
-        self.dropout_gen_mask = P.DropoutGenMask(Seed0=self.seed0, Seed1=self.seed1)
-        self.dropout_do_mask = P.DropoutDoMask()
-        self.cast = P.Cast()
-        self.is_ascend = context.get_context('device_target') in ["Ascend"]
-        self.dropout = P.Dropout(keep_prob)
+        self.dropout = P.Dropout(keep_prob, seed0, seed1)
 
     def construct(self, x):
         if not self.training:
             return x
 
-        if not self.is_ascend:
-            out, _ = self.dropout(x)
-            return out
-
         if self.keep_prob == 1:
             return x
 
-        shape = self.get_shape(x)
-        dtype = P.DType()(x)
-        if _is_float_dtype(dtype):
-            keep_prob = self.cast(self.keep_prob, dtype)
-        else:
-            keep_prob = self.cast(self.keep_prob, mstype.float16)
-        output = self.dropout_gen_mask(shape, keep_prob)
-        return self.dropout_do_mask(x, output, keep_prob)
+        out, _ = self.dropout(x)
+        return out
 
     def extend_repr(self):
         return 'keep_prob={}, dtype={}'.format(self.keep_prob, self.dtype)
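
With the layer-side change above, nn.Dropout always dispatches to the single P.Dropout primitive (seeds passed through) instead of the Ascend-only DropoutGenMask/DropoutDoMask pair, so the same construct path runs on every backend. A minimal usage sketch (standard nn.Dropout API, shown for illustration only):

import numpy as np
from mindspore import Tensor, nn

net = nn.Dropout(keep_prob=0.8)
net.set_train()                      # dropout only acts in training mode
x = Tensor(np.ones((4, 4)).astype(np.float32))
y = net(x)                           # kept values are scaled by 1 / keep_prob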
