Support NoNeedBufferVarsInference in dygraph backward (#20868)
* support no need buffer vars in dygraph, test=develop * fix inference compilation error, test=develop * update no_need_buffer_vars_inference, test=develop * add unittests for no_need_buffer_vars_context, test=develop * refine no_need_buffer_vars by return ref, test=develop * polish some codes, test=developcustom_op_abi
parent
bf379fef96
commit
878a40f57d
@ -0,0 +1,64 @@
|
||||
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/framework/no_need_buffer_vars_inference.h"
|
||||
#include <string>
|
||||
#include "paddle/fluid/framework/operator.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
|
||||
const Attribute &InferNoNeedBufferVarsContext::GetAttr(
|
||||
const std::string &name) const {
|
||||
auto iter = attrs_.find(name);
|
||||
PADDLE_ENFORCE_EQ(iter != attrs_.end(), true, "Cannot find attribute %s",
|
||||
name);
|
||||
return iter->second;
|
||||
}
|
||||
|
||||
StaticGraphInferNoNeedBufferVarsContext::
|
||||
StaticGraphInferNoNeedBufferVarsContext(const VariableNameMap &inputs,
|
||||
const VariableNameMap &outputs,
|
||||
const AttributeMap &attrs)
|
||||
: InferNoNeedBufferVarsContext(attrs), inputs_(inputs), outputs_(outputs) {}
|
||||
|
||||
bool StaticGraphInferNoNeedBufferVarsContext::HasOutput(
|
||||
const std::string &slot) const {
|
||||
auto iter = outputs_.find(slot);
|
||||
if (iter != outputs_.end()) {
|
||||
for (auto &var : iter->second) {
|
||||
if (var != kEmptyVarName) return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
DyGraphInferNoNeedBufferVarsContext::DyGraphInferNoNeedBufferVarsContext(
|
||||
const imperative::NameVarBaseMap &inputs,
|
||||
const imperative::NameVarBaseMap &outputs, const AttributeMap &attrs)
|
||||
: InferNoNeedBufferVarsContext(attrs), inputs_(inputs), outputs_(outputs) {}
|
||||
|
||||
bool DyGraphInferNoNeedBufferVarsContext::HasOutput(
|
||||
const std::string &slot) const {
|
||||
auto iter = outputs_.find(slot);
|
||||
if (iter != outputs_.end()) {
|
||||
for (auto &var : iter->second) {
|
||||
if (var) return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,52 @@
|
||||
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/framework/no_need_buffer_vars_inference.h"
|
||||
#include "gtest/gtest.h"
|
||||
#include "paddle/fluid/framework/operator.h"
|
||||
#include "paddle/fluid/imperative/layer.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
|
||||
TEST(test_no_need_buffer_vars_inference, test_static_graph) {
|
||||
AttributeMap attrs{{"is_test", true}};
|
||||
VariableNameMap inputs;
|
||||
VariableNameMap outputs{{"Out", {kEmptyVarName, "tmp_0"}}};
|
||||
|
||||
StaticGraphInferNoNeedBufferVarsContext ctx(inputs, outputs, attrs);
|
||||
|
||||
ASSERT_TRUE(ctx.HasOutput("Out"));
|
||||
ASSERT_FALSE(ctx.HasOutput("X"));
|
||||
|
||||
ASSERT_TRUE(boost::get<bool>(ctx.GetAttr("is_test")));
|
||||
}
|
||||
|
||||
TEST(test_no_need_buffer_vars_inference, test_dygraph) {
|
||||
AttributeMap attrs{{"is_test", true}};
|
||||
imperative::NameVarBaseMap inputs;
|
||||
imperative::NameVarBaseMap outputs;
|
||||
outputs["Out"].emplace_back(nullptr);
|
||||
outputs["Out"].emplace_back(new imperative::VarBase("tmp_0"));
|
||||
|
||||
DyGraphInferNoNeedBufferVarsContext ctx(inputs, outputs, attrs);
|
||||
|
||||
ASSERT_TRUE(ctx.HasOutput("Out"));
|
||||
ASSERT_FALSE(ctx.HasOutput("X"));
|
||||
|
||||
ASSERT_TRUE(boost::get<bool>(ctx.GetAttr("is_test")));
|
||||
}
|
||||
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
Loading…
Reference in new issue