fix affine_channel no_need buffer bug, test=develop (#18844)

DDDivano-patch-1
Zeng Jinle 6 years ago committed by GitHub
parent 829ef26281
commit 9a8a7a1ddc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -295,10 +295,10 @@ class AffineChannelNoNeedBufferVarsInference
using framework::NoNeedBufferVarsInference::NoNeedBufferVarsInference;
private:
inline bool HasInput(const std::string& name) const {
auto& inputs = Inputs();
auto iter = inputs.find(name);
if (iter == inputs.end() || iter->second.empty()) {
inline bool HasOutput(const std::string& name) const {
auto& outputs = Outputs();
auto iter = outputs.find(name);
if (iter == outputs.end() || iter->second.empty()) {
return false;
} else {
return iter->second[0] != framework::kEmptyVarName;
@ -306,9 +306,9 @@ class AffineChannelNoNeedBufferVarsInference
}
public:
std::unordered_set<std::string> operator()() const {
if (!HasInput(framework::GradVarName("Scale")) &&
!HasInput(framework::GradVarName("Bias"))) {
std::unordered_set<std::string> operator()() const override {
if (!HasOutput(framework::GradVarName("Scale")) &&
!HasOutput(framework::GradVarName("Bias"))) {
return {"X"};
} else {
return {};

@ -163,7 +163,8 @@ list(REMOVE_ITEM TEST_OPS test_basic_lstm_unit_op)
# Some ops need to check results when gc is enabled
# Currently, only ops that register NoNeedBufferVarsInference need to do this test
set(TEST_OPS_WITH_GC
set(TEST_OPS_WITH_GC
test_affine_channel_op
test_concat_op
test_elementwise_add_op
test_elementwise_sub_op

Loading…
Cancel
Save