Avoid crash when calling ctx->HasInputs and add the check of shape in fill_copnstant op. (#23698)

revert-23830-2.0-beta
Yiqun Liu 5 years ago committed by GitHub
parent ac4da77aa6
commit 9e85d02373
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -721,6 +721,9 @@ CompileTimeInferShapeContext::CompileTimeInferShapeContext(
: op_(op), block_(block) {}
bool CompileTimeInferShapeContext::HasInput(const std::string &name) const {
if (op_.Inputs().find(name) == op_.Inputs().end()) {
return false;
}
const std::vector<std::string> &input_names = op_.Input(name);
auto length = input_names.size();
if (length == 0) {
@ -734,6 +737,9 @@ bool CompileTimeInferShapeContext::HasInput(const std::string &name) const {
}
bool CompileTimeInferShapeContext::HasOutput(const std::string &name) const {
if (op_.Outputs().find(name) == op_.Outputs().end()) {
return false;
}
const std::vector<std::string> &output_names = op_.Output(name);
auto length = output_names.size();
if (length == 0) {
@ -747,6 +753,9 @@ bool CompileTimeInferShapeContext::HasOutput(const std::string &name) const {
}
bool CompileTimeInferShapeContext::HasInputs(const std::string &name) const {
if (op_.Inputs().find(name) == op_.Inputs().end()) {
return false;
}
const std::vector<std::string> &input_names = op_.Input(name);
if (input_names.empty()) {
return false;
@ -758,6 +767,9 @@ bool CompileTimeInferShapeContext::HasInputs(const std::string &name) const {
}
bool CompileTimeInferShapeContext::HasOutputs(const std::string &name) const {
if (op_.Outputs().find(name) == op_.Outputs().end()) {
return false;
}
const std::vector<std::string> &output_names = op_.Output(name);
if (output_names.empty()) {
return false;

@ -25,6 +25,16 @@ class FillConstantOp : public framework::OperatorWithKernel {
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "FillConstant");
auto& shape = ctx->Attrs().Get<std::vector<int64_t>>("shape");
if (!ctx->HasInput("ShapeTensor") && !ctx->HasInputs("ShapeTensorList")) {
for (size_t i = 0; i < shape.size(); ++i) {
PADDLE_ENFORCE_GE(
shape[i], 0,
platform::errors::InvalidArgument(
"Each value of attribute 'shape' is expected to be greater "
"than 0. But recieved: shape[%u] = %d; shape = [%s].",
i, shape[i], framework::make_ddim(shape)));
}
}
if (shape.empty() && ctx->HasInput("ShapeTensor")) {
auto shape_dims = ctx->GetInputDim("ShapeTensor");

@ -369,11 +369,11 @@ class SeqPGAgent(object):
self.probs, self.samples, self.sample_length = self.model(
source, source_length, target, target_length)
self.samples.stop_gradient = True
self.reward = fluid.layers.create_global_var(
self.reward = fluid.data(
name="reward",
shape=[-1, -1], # batch_size, seq_len
value="1",
shape=[None, None], # batch_size, seq_len
dtype=self.probs.dtype)
self.samples.stop_gradient = False
self.cost = self.alg.learn(self.probs, self.samples, self.reward,
self.sample_length)

Loading…
Cancel
Save