|
|
|
@ -33,7 +33,8 @@ using DDim = framework::DDim;
|
|
|
|
|
|
|
|
|
|
void CondOp::CreateScope(const Scope& scope) const {
|
|
|
|
|
auto sub_scopes_var = scope.FindVar("SubScopes");
|
|
|
|
|
PADDLE_ENFORCE(sub_scopes_var != nullptr, "");
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(sub_scopes_var,
|
|
|
|
|
"Output(SubScopes) of CondOp should not be null.");
|
|
|
|
|
auto sub_scopes = sub_scopes_var->GetMutable<std::vector<Scope*>>();
|
|
|
|
|
auto& sub_scope = scope.NewScope();
|
|
|
|
|
sub_scopes->push_back(&sub_scope);
|
|
|
|
@ -41,7 +42,8 @@ void CondOp::CreateScope(const Scope& scope) const {
|
|
|
|
|
|
|
|
|
|
void CondOp::CreateIndexTensor(const Scope& scope) const {
|
|
|
|
|
auto index_tensors_var = scope.FindVar("IndexTensors");
|
|
|
|
|
PADDLE_ENFORCE(index_tensors_var != nullptr, "");
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(index_tensors_var,
|
|
|
|
|
"Output(IndexTensors) of CondOp should not be null.");
|
|
|
|
|
auto& index_tensors =
|
|
|
|
|
*index_tensors_var->GetMutable<std::vector<LoDTensor>>();
|
|
|
|
|
index_tensors.push_back(LoDTensor());
|
|
|
|
@ -49,7 +51,8 @@ void CondOp::CreateIndexTensor(const Scope& scope) const {
|
|
|
|
|
|
|
|
|
|
void CondOp::InferShape(const Scope& scope) const {
|
|
|
|
|
auto sub_scopes_var = scope.FindVar("SubScopes");
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(sub_scopes_var);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(sub_scopes_var,
|
|
|
|
|
"Output(SubScopes) of CondOp should not be null.");
|
|
|
|
|
auto& sub_scopes = *sub_scopes_var->GetMutable<std::vector<Scope*>>();
|
|
|
|
|
|
|
|
|
|
for (int i = 0; i < 2; ++i) {
|
|
|
|
@ -63,7 +66,8 @@ void CondOp::InferShape(const Scope& scope) const {
|
|
|
|
|
// branch
|
|
|
|
|
CreateIndexTensor(scope);
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(!Inputs("Xs").empty(), "Inputs can't be empty");
|
|
|
|
|
PADDLE_ENFORCE(!Inputs("Xs").empty(),
|
|
|
|
|
"Inputs(Xs) of CondOp can't be empty.");
|
|
|
|
|
for (auto& input : Inputs("Xs")) {
|
|
|
|
|
// Create a new tensor in sub-scope for input-type tensor
|
|
|
|
|
Variable* v = sub_scopes[i]->NewVar(input);
|
|
|
|
@ -108,13 +112,18 @@ void CondOp::InferShape(const Scope& scope) const {
|
|
|
|
|
void CondOp::Run(const Scope& scope,
|
|
|
|
|
const platform::DeviceContext& dev_ctx) const {
|
|
|
|
|
auto* sub_scopes_var = scope.FindVar("SubScopes");
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(sub_scopes_var,
|
|
|
|
|
"Output(SubScopes) of CondOp should not be null.");
|
|
|
|
|
auto sub_scopes = sub_scopes_var->Get<std::vector<Scope*>>();
|
|
|
|
|
auto* index_tensors_var = scope.FindVar("IndexTensors");
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(index_tensors_var,
|
|
|
|
|
"Output(IndexTensors) of CondOp should not be null.");
|
|
|
|
|
auto index_tensors = index_tensors_var->Get<std::vector<LoDTensor>>();
|
|
|
|
|
|
|
|
|
|
std::string cond_name = Input("Cond");
|
|
|
|
|
Variable* cond_var = scope.FindVar(cond_name);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(cond_var);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(cond_var,
|
|
|
|
|
"Input(Cond) of CondOp should not be null.");
|
|
|
|
|
const LoDTensor* cond = cond_var->GetMutable<LoDTensor>();
|
|
|
|
|
|
|
|
|
|
// Step 1: get the true/false index at runtime
|
|
|
|
@ -171,6 +180,8 @@ void CondOp::Run(const Scope& scope,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Step 4: merge output results
|
|
|
|
|
PADDLE_ENFORCE(!Outputs("Outs").empty(),
|
|
|
|
|
"Outputs(Outs) of CondOp can't be empty.");
|
|
|
|
|
for (int i = 0; i < 2; ++i) {
|
|
|
|
|
// i= 0/i for True and False branches respectively
|
|
|
|
|
for (auto& output : Outputs("Outs")) {
|
|
|
|
|