|
|
|
@ -295,10 +295,10 @@ class AffineChannelNoNeedBufferVarsInference
|
|
|
|
|
using framework::NoNeedBufferVarsInference::NoNeedBufferVarsInference;
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
inline bool HasInput(const std::string& name) const {
|
|
|
|
|
auto& inputs = Inputs();
|
|
|
|
|
auto iter = inputs.find(name);
|
|
|
|
|
if (iter == inputs.end() || iter->second.empty()) {
|
|
|
|
|
inline bool HasOutput(const std::string& name) const {
|
|
|
|
|
auto& outputs = Outputs();
|
|
|
|
|
auto iter = outputs.find(name);
|
|
|
|
|
if (iter == outputs.end() || iter->second.empty()) {
|
|
|
|
|
return false;
|
|
|
|
|
} else {
|
|
|
|
|
return iter->second[0] != framework::kEmptyVarName;
|
|
|
|
@ -306,9 +306,9 @@ class AffineChannelNoNeedBufferVarsInference
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
public:
|
|
|
|
|
std::unordered_set<std::string> operator()() const {
|
|
|
|
|
if (!HasInput(framework::GradVarName("Scale")) &&
|
|
|
|
|
!HasInput(framework::GradVarName("Bias"))) {
|
|
|
|
|
std::unordered_set<std::string> operator()() const override {
|
|
|
|
|
if (!HasOutput(framework::GradVarName("Scale")) &&
|
|
|
|
|
!HasOutput(framework::GradVarName("Bias"))) {
|
|
|
|
|
return {"X"};
|
|
|
|
|
} else {
|
|
|
|
|
return {};
|
|
|
|
|