|
|
|
@ -18,22 +18,22 @@ namespace paddle {
|
|
|
|
|
namespace framework {
|
|
|
|
|
namespace details {
|
|
|
|
|
template <typename Func>
|
|
|
|
|
static void VisitVariable(Variable* var, Func func) {
|
|
|
|
|
static void VisitVariable(Variable* var, Func* func) {
|
|
|
|
|
if (var->IsType<LoDTensor>()) {
|
|
|
|
|
func(var->GetMutable<LoDTensor>());
|
|
|
|
|
(*func)(var->GetMutable<LoDTensor>());
|
|
|
|
|
} else if (var->IsType<SelectedRows>()) {
|
|
|
|
|
func(var->GetMutable<SelectedRows>());
|
|
|
|
|
(*func)(var->GetMutable<SelectedRows>());
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_THROW("Not supported type %s", var->Type().name());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename Func>
|
|
|
|
|
static void VisitVariable(const Variable& var, Func func) {
|
|
|
|
|
static void VisitVariable(const Variable& var, Func* func) {
|
|
|
|
|
if (var.IsType<LoDTensor>()) {
|
|
|
|
|
func(var.Get<LoDTensor>());
|
|
|
|
|
(*func)(var.Get<LoDTensor>());
|
|
|
|
|
} else if (var.IsType<SelectedRows>()) {
|
|
|
|
|
func(var.Get<SelectedRows>());
|
|
|
|
|
(*func)(var.Get<SelectedRows>());
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_THROW("Not supported type %s", var.Type().name());
|
|
|
|
|
}
|
|
|
|
@ -56,7 +56,7 @@ struct TensorVisitor {
|
|
|
|
|
|
|
|
|
|
Tensor& VariableVisitor::GetMutableTensor(Variable* var) {
|
|
|
|
|
TensorVisitor vistor;
|
|
|
|
|
VisitVariable(var, vistor);
|
|
|
|
|
VisitVariable(var, &vistor);
|
|
|
|
|
return *vistor.result_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -85,7 +85,7 @@ struct ShareDimsAndLoDVisitor {
|
|
|
|
|
|
|
|
|
|
void VariableVisitor::ShareDimsAndLoD(const Variable& src, Variable* trg) {
|
|
|
|
|
ShareDimsAndLoDVisitor visitor{trg};
|
|
|
|
|
VisitVariable(src, visitor);
|
|
|
|
|
VisitVariable(src, &visitor);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} // namespace details
|
|
|
|
|