fix VisitVariable

wangkuiyi-patch-2
chengduoZH 7 years ago
parent fbb75c6bb6
commit 035712822c

@ -29,9 +29,7 @@ namespace framework {
namespace details {
struct BroadcastOpHandle : public OpHandleBase {
const std::vector<Scope *> &local_scopes_;
const std::vector<platform::Place> &places_;
public:
BroadcastOpHandle(const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places);
@ -41,10 +39,12 @@ struct BroadcastOpHandle : public OpHandleBase {
protected:
void RunImpl() override;
void WaitInputVarGenerated(const VarHandle &in_var);
};
private:
const std::vector<Scope *> &local_scopes_;
const std::vector<platform::Place> &places_;
};
} // namespace details
} // namespace framework
} // namespace paddle

@ -18,22 +18,22 @@ namespace paddle {
namespace framework {
namespace details {
template <typename Func>
static void VisitVariable(Variable* var, Func func) {
static void VisitVariable(Variable* var, Func* func) {
if (var->IsType<LoDTensor>()) {
func(var->GetMutable<LoDTensor>());
(*func)(var->GetMutable<LoDTensor>());
} else if (var->IsType<SelectedRows>()) {
func(var->GetMutable<SelectedRows>());
(*func)(var->GetMutable<SelectedRows>());
} else {
PADDLE_THROW("Not supported type %s", var->Type().name());
}
}
template <typename Func>
static void VisitVariable(const Variable& var, Func func) {
static void VisitVariable(const Variable& var, Func* func) {
if (var.IsType<LoDTensor>()) {
func(var.Get<LoDTensor>());
(*func)(var.Get<LoDTensor>());
} else if (var.IsType<SelectedRows>()) {
func(var.Get<SelectedRows>());
(*func)(var.Get<SelectedRows>());
} else {
PADDLE_THROW("Not supported type %s", var.Type().name());
}
@ -56,7 +56,7 @@ struct TensorVisitor {
Tensor& VariableVisitor::GetMutableTensor(Variable* var) {
TensorVisitor vistor;
VisitVariable(var, vistor);
VisitVariable(var, &vistor);
return *vistor.result_;
}
@ -85,7 +85,7 @@ struct ShareDimsAndLoDVisitor {
void VariableVisitor::ShareDimsAndLoD(const Variable& src, Variable* trg) {
ShareDimsAndLoDVisitor visitor{trg};
VisitVariable(src, visitor);
VisitVariable(src, &visitor);
}
} // namespace details

Loading…
Cancel
Save