Create tensor in recv op (#7286)

* create tensor in recv op

* static global function to global function
detection_output_fixbug
Yancey 8 years ago committed by GitHub
parent 2d10c75b94
commit aa75f1e2c5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -35,7 +35,7 @@ const std::string kFetchOpType = "fetch";
Executor::Executor(const platform::Place& place) : place_(place) {}
static void CreateTensor(Variable* var, proto::VarDesc::VarType var_type) {
void CreateTensor(Variable* var, proto::VarDesc::VarType var_type) {
if (var_type == proto::VarDesc::LOD_TENSOR) {
var->GetMutable<LoDTensor>();
} else if (var_type == proto::VarDesc::SELECTED_ROWS) {

@ -45,5 +45,7 @@ class Executor {
const platform::Place place_;
};
void CreateTensor(Variable* var, proto::VarDesc::VarType var_type);
} // namespace framework
} // namespace paddle

@ -19,7 +19,6 @@ limitations under the License. */
#include <unistd.h>
#include "paddle/framework/data_type.h"
#include "paddle/framework/executor.h"
#include "paddle/framework/framework.pb.h"
#include "paddle/framework/lod_tensor.h"
@ -111,9 +110,11 @@ class RecvOp : public framework::OperatorBase {
<< " updating param: " << param_var_name;
auto *merged_grad = recv_scope.FindVar(grad_var_name);
if (merged_grad == nullptr) {
// create output of merged var.
auto merged_var = recv_scope.Var(grad_var_name);
merged_var->GetMutable<framework::LoDTensor>();
auto *ptr = recv_scope.Var(grad_var_name);
framework::CreateTensor(ptr,
framework::ToVarType(merged_grad->Type()));
VLOG(3) << "Create Variable " << grad_var_name
<< " on recv scope, which pointer is " << ptr;
}
if (trainer_count > 1) {

Loading…
Cancel
Save