fix runtime failure for output var with stop_gradient=True for fusion group (#23317)

wangchaochaohu 5 years ago committed by GitHub
parent b76f3b2727
commit d085f79228

@@ -99,6 +99,7 @@ std::vector<OperationExpression> CodeGenerator::ConvertToExpressions(
input_ids.push_back(-1);
}
}
// Output ids should be set in fixed order, like:
// - dx, dy in backward operations
std::vector<int> output_ids;
@@ -106,10 +107,6 @@ std::vector<OperationExpression> CodeGenerator::ConvertToExpressions(
OperationMap::Instance().Get(op->Type()).output_names;
for (auto& name : output_names) {
PADDLE_ENFORCE_EQ(
op->Output(name).size(), 1U,
platform::errors::InvalidArgument(
"Output(%s) of operation %s is not set.", name, op->Type()));
PADDLE_ENFORCE_NE(
var_ids.find(op->Output(name)[0]), var_ids.end(),
platform::errors::InvalidArgument(

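The two hunks above are from CodeGenerator::ConvertToExpressions: output ids are collected in the fixed, registered order (dx before dy for backward operations) and each output variable is translated to an integer id through the var_ids table. A minimal Python sketch of that step, using plain dicts in place of Paddle's OpDesc and OperationMap (all names here are illustrative, not Paddle's API):

```python
def collect_output_ids(op_outputs, registered_output_names, var_ids):
    """Gather output ids in the fixed, registered order (e.g. dx before dy
    for backward ops), mapping each output variable through var_ids."""
    output_ids = []
    for name in registered_output_names:
        var_name = op_outputs[name][0]        # exactly one variable per output name
        output_ids.append(var_ids[var_name])  # a KeyError here mirrors the PADDLE_ENFORCE_NE
    return output_ids

# Illustrative use for an elementwise_add_grad-like op:
print(collect_output_ids({"X@GRAD": ["dx"], "Y@GRAD": ["dy"]},
                         ["X@GRAD", "Y@GRAD"],
                         {"dx": 3, "dy": 4}))  # -> [3, 4]
```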
@@ -111,6 +111,15 @@ bool ElementwiseGroupDetector::IsElementwiseOp(const Node* n) {
}
}
}
auto op = n->Op();
std::vector<std::string> output_names =
OperationMap::Instance().Get(op->Type()).output_names;
for (auto& name : output_names) {
if (op->Output(name).size() != 1) return false;
}
return true;
}
return false;
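The check added to ElementwiseGroupDetector::IsElementwiseOp only accepts an op when every registered output name is bound to exactly one variable, so an op whose output is missing (for example a gradient output that is not generated because a variable has stop_gradient=True, the case named in the commit title) is kept out of the fusion group instead of failing at runtime. A minimal Python sketch of the same predicate, with illustrative data structures in place of Paddle's C++ types:

```python
def has_single_var_per_output(op_outputs, registered_output_names):
    # Every registered output name must map to exactly one variable,
    # otherwise the op is not treated as a fusable elementwise op.
    return all(len(op_outputs.get(name, [])) == 1
               for name in registered_output_names)

# Hypothetical backward op whose Y@GRAD was pruned (stop_gradient=True on y):
print(has_single_var_per_output({"X@GRAD": ["x@GRAD"], "Y@GRAD": []},
                                ["X@GRAD", "Y@GRAD"]))  # -> False, op is skipped
```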

@@ -109,11 +109,9 @@ class FusionGroupPassTest2(FusionGroupPassTest):
tmp_2 = layers.relu(layers.sigmoid(self.feed_vars[3]))
tmp_3 = layers.mul(tmp_1, tmp_2)
# TODO(wangchaochaohu): support the case when some vars are set
# stop_gradient = True.
self.append_gradients(tmp_3)
self.num_fused_ops = 2
self.fetch_list = [tmp_3]
self.fetch_list = [tmp_3, self.grad(tmp_1)]
class FusionGroupPassTestFP64(FusionGroupPassTest):
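The updated test now fetches the gradient of an intermediate variable alongside the forward result, which is exactly the situation the detector change has to handle. Below is a standalone sketch of that pattern written against the old paddle.fluid API; the shapes, the reduce_sum loss and the var.name + "@GRAD" lookup are assumptions for illustration, and the real test drives the fusion_group pass through its PassTest harness on GPU rather than a plain Executor run:

```python
import numpy as np
import paddle.fluid as fluid
from paddle.fluid import layers

main, startup = fluid.Program(), fluid.Program()
with fluid.program_guard(main, startup):
    x0 = layers.data("x0", shape=[16, 32], dtype="float32", append_batch_size=False)
    x1 = layers.data("x1", shape=[32, 16], dtype="float32", append_batch_size=False)
    tmp_1 = layers.relu(layers.sigmoid(x0))
    tmp_2 = layers.relu(layers.sigmoid(x1))
    tmp_3 = layers.mul(tmp_1, tmp_2)
    loss = layers.reduce_sum(tmp_3)
    fluid.backward.append_backward(loss)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(startup)
feed = {"x0": np.random.random([16, 32]).astype("float32"),
        "x1": np.random.random([32, 16]).astype("float32")}
# Fetch the forward output and the gradient of an intermediate variable,
# mirroring fetch_list = [tmp_3, self.grad(tmp_1)] in the updated test.
outs = exe.run(main, feed=feed, fetch_list=[tmp_3.name, tmp_1.name + "@GRAD"])
```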
