fix dump, fix cvm check (#25400)

* fix dump, fix cvm check
test=develop

* fix
test=develop

* fix
test=develop

* fix
test=develop
revert-24895-update_cub
xujiaqi01 5 years ago committed by GitHub
parent 8ebffc78c9
commit d11c140e28
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -111,6 +111,7 @@ void DeviceWorker::DumpParam(const Scope& scope, const int batch_id) {
writer_ << os.str();
}
}
void DeviceWorker::InitRandomDumpConfig(const TrainerDesc& desc) {
bool enable_random_dump = desc.enable_random_dump();
if (!enable_random_dump) {

@ -99,7 +99,7 @@ void DistMultiTrainer::InitTrainerEnv(const ProgramDesc &main_program,
}
void DistMultiTrainer::InitOtherEnv(const ProgramDesc &main_program) {
if (need_dump_field_) {
if (need_dump_field_ || need_dump_param_) {
InitDumpEnv();
}
pull_dense_worker_->SetRootScope(root_scope_);
@ -158,7 +158,7 @@ void DistMultiTrainer::Finalize() {
}
}
if (need_dump_field_) {
if (need_dump_field_ || need_dump_param_) {
FinalizeDumpEnv();
}
pull_dense_worker_->Stop();

@ -49,7 +49,12 @@ TEST(DisMultiTrainerTest, test1) {
dataset->SetTrainerNum(1);
dataset->SetDataFeedDesc(str);
dataset->CreateReaders();
Scope root_scope;
tmp1->SetScope(&root_scope);
tmp1->Initialize(t, dataset.get());
ProgramDesc p;
tmp1->InitOtherEnv(p);
tmp1->Finalize();
#endif
}
} // namespace framework

@ -106,7 +106,7 @@ void MultiTrainer::InitTrainerEnv(const ProgramDesc& main_program,
}
void MultiTrainer::InitOtherEnv(const ProgramDesc& main_program) {
if (need_dump_field_) {
if (need_dump_field_ || need_dump_param_) {
InitDumpEnv();
}
VLOG(3) << "init other env done.";
@ -133,7 +133,7 @@ void MultiTrainer::Run() {
}
void MultiTrainer::Finalize() {
if (need_dump_field_) {
if (need_dump_field_ || need_dump_param_) {
FinalizeDumpEnv();
}
root_scope_->DropKids();

@ -22,6 +22,8 @@ void TrainerBase::SetScope(Scope* root_scope) { root_scope_ = root_scope; }
void TrainerBase::ParseDumpConfig(const TrainerDesc& desc) {
dump_fields_path_ = desc.dump_fields_path();
need_dump_field_ = false;
need_dump_param_ = false;
if (dump_fields_path_ == "") {
VLOG(2) << "dump_fields_path_ is empty";
return;

@ -27,19 +27,11 @@ class CVMOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "CVM");
OP_INOUT_CHECK(ctx->HasInput("CVM"), "Input", "CVM", "CVM");
OP_INOUT_CHECK(ctx->HasOutput("Y"), "Output", "Y", "CVM");
auto x_dims = ctx->GetInputDim("X");
auto cvm_dims = ctx->GetInputDim("CVM");
PADDLE_ENFORCE_EQ(x_dims.size(), 2UL, platform::errors::InvalidArgument(
"Input(X)'s rank should be 2."));
PADDLE_ENFORCE_EQ(
cvm_dims.size(), 2UL,
platform::errors::InvalidArgument("Input(CVM)'s rank should be 2."));
PADDLE_ENFORCE_EQ(cvm_dims[1], 2UL, platform::errors::InvalidArgument(
"The 2nd dimension of "
"Input(CVM) should be 2."));
if (ctx->Attrs().Get<bool>("use_cvm")) {
ctx->SetOutputDim("Y", {x_dims[0], x_dims[1]});

@ -62,15 +62,18 @@ class TrainerFactory(object):
trainer._set_mpi_rank(opt_info["mpi_rank"])
if opt_info.get("mpi_size") is not None:
trainer._set_mpi_size(opt_info["mpi_size"])
if opt_info.get("dump_fields") is not None:
if opt_info.get("dump_fields") is not None and len(
opt_info.get("dump_fields")) != 0:
trainer._set_dump_fields(opt_info["dump_fields"])
if opt_info.get("dump_fields_path") is not None:
if opt_info.get("dump_fields_path") is not None and len(
opt_info.get("dump_fields_path")) != 0:
trainer._set_dump_fields_path(opt_info["dump_fields_path"])
if opt_info.get("dump_file_num") is not None:
trainer._set_dump_file_num(opt_info["dump_file_num"])
if opt_info.get("dump_converter") is not None:
trainer._set_dump_converter(opt_info["dump_converter"])
if opt_info.get("dump_param") is not None:
if opt_info.get("dump_param") is not None and len(
opt_info.get("dump_param")) != 0:
trainer._set_dump_param(opt_info["dump_param"])
if opt_info.get("enable_random_dump") is not None:
trainer._set_enable_random_dump(opt_info[

Loading…
Cancel
Save