You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Paddle/paddle/fluid/distributed/service/server.cc

105 lines
3.6 KiB

// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/service/server.h"
#include "glog/logging.h"
#include "paddle/fluid/distributed/service/brpc_ps_server.h"
#include "paddle/fluid/distributed/table/table.h"
namespace paddle {
namespace distributed {
REGISTER_PSCORE_CLASS(PSServer, BrpcPsServer);
REGISTER_PSCORE_CLASS(PsBaseService, BrpcPsService);
PSServer *PSServerFactory::create(const PSParameter &ps_config) {
const auto &config = ps_config.server_param();
if (!config.has_downpour_server_param()) {
LOG(ERROR) << "miss downpour_server_param in ServerParameter";
return NULL;
}
if (!config.downpour_server_param().has_service_param()) {
LOG(ERROR) << "miss service_param in ServerParameter.downpour_server_param";
return NULL;
}
if (!config.downpour_server_param().service_param().has_server_class()) {
LOG(ERROR) << "miss server_class in "
"ServerParameter.downpour_server_param.service_param";
return NULL;
}
const auto &service_param = config.downpour_server_param().service_param();
PSServer *server =
CREATE_PSCORE_CLASS(PSServer, service_param.server_class());
if (server == NULL) {
LOG(ERROR) << "server is not registered, server_name:"
<< service_param.server_class();
return NULL;
}
TableManager::instance().initialize();
return server;
}
int32_t PSServer::configure(
const PSParameter &config, PSEnvironment &env, size_t server_rank,
const std::vector<framework::ProgramDesc> &server_sub_program) {
scope_.reset(new framework::Scope());
_config = config.server_param();
_rank = server_rank;
_environment = &env;
_shuffled_ins =
paddle::framework::MakeChannel<std::pair<uint64_t, std::string>>();
size_t shard_num = env.get_ps_servers().size();
const auto &downpour_param = _config.downpour_server_param();
uint32_t barrier_table = UINT32_MAX;
uint32_t global_step_table = UINT32_MAX;
for (size_t i = 0; i < downpour_param.downpour_table_param_size(); ++i) {
auto *table = CREATE_PSCORE_CLASS(
Table, downpour_param.downpour_table_param(i).table_class());
if (downpour_param.downpour_table_param(i).table_class() ==
"BarrierTable") {
barrier_table = downpour_param.downpour_table_param(i).table_id();
}
if (downpour_param.downpour_table_param(i).table_class() ==
"GlobalStepTable") {
global_step_table = downpour_param.downpour_table_param(i).table_id();
}
table->set_program_env(scope_.get(), place_, &server_sub_program);
table->set_shard(_rank, shard_num);
table->initialize(downpour_param.downpour_table_param(i),
config.fs_client_param());
_table_map[downpour_param.downpour_table_param(i).table_id()].reset(table);
}
if (barrier_table != UINT32_MAX) {
_table_map[barrier_table]->set_table_map(&_table_map);
}
if (global_step_table != UINT32_MAX) {
_table_map[global_step_table]->set_table_map(&_table_map);
}
return initialize();
}
} // namespace distributed
} // namespace paddle