@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
+#include "paddle/fluid/framework/data_type.h"
 #include "paddle/fluid/framework/device_worker.h"
 #include "paddle/fluid/framework/device_worker_factory.h"
 #include "paddle/fluid/platform/cpu_helper.h"
@@ -20,7 +21,7 @@ limitations under the License. */
 namespace paddle {
 namespace framework {
 
-void HogwildWorker::Initialize(const TrainerDesc& desc) {
+void HogwildWorker::Initialize(const TrainerDesc &desc) {
   fetch_config_ = desc.fetch_config();
   param_ = desc.hogwild_param();
   skip_ops_.resize(param_.skip_ops_size());
@@ -30,45 +31,70 @@ void HogwildWorker::Initialize(const TrainerDesc& desc) {
   use_cvm_ = desc.use_cvm();
 }
 
-void HogwildWorker::CreateThreadOperators(const ProgramDesc& program) {
-  auto& block = program.Block(0);
+void HogwildWorker::CreateThreadOperators(const ProgramDesc &program) {
+  auto &block = program.Block(0);
   op_names_.clear();
-  for (auto& op_desc : block.AllOps()) {
+  for (auto &op_desc : block.AllOps()) {
     std::unique_ptr<OperatorBase> local_op = OpRegistry::CreateOp(*op_desc);
     op_names_.push_back(op_desc->Type());
-    OperatorBase* local_op_ptr = local_op.release();
+    OperatorBase *local_op_ptr = local_op.release();
     ops_.push_back(local_op_ptr);
     continue;
   }
 }
 
-void HogwildWorker::CreateThreadScope(const ProgramDesc& program) {
-  auto& block = program.Block(0);
+void HogwildWorker::CreateThreadScope(const ProgramDesc &program) {
+  auto &block = program.Block(0);
+
   PADDLE_ENFORCE_NOT_NULL(
       root_scope_, "root_scope should be set before creating thread scope");
+
   thread_scope_ = &root_scope_->NewScope();
-  for (auto& var : block.AllVars()) {
+
+  for (auto &var : block.AllVars()) {
     if (var->Persistable()) {
-      auto* ptr = root_scope_->Var(var->Name());
+      auto *ptr = root_scope_->Var(var->Name());
       InitializeVariable(ptr, var->GetType());
+      if (stat_var_name_map_.find(var->Name()) != stat_var_name_map_.end() &&
+          thread_id_ != 0) {
+        int tensor_dim =
+            root_scope_->FindVar(var->Name())->GetMutable<LoDTensor>()->numel();
+        auto *ptr1 = thread_scope_->Var(var->Name());
+        InitializeVariable(ptr1, var->GetType());
+        LoDTensor *thread_tensor = ptr1->GetMutable<LoDTensor>();
+        LoDTensor *root_tensor =
+            root_scope_->FindVar(var->Name())->GetMutable<LoDTensor>();
+#define MemsetCallback(cpp_type, proto_type)                     \
+  do {                                                           \
+    if (root_tensor->type() == proto_type) {                     \
+      SetZero<cpp_type>(thread_tensor, root_tensor, tensor_dim); \
+    }                                                            \
+  } while (0)
+        _ForEachDataType_(MemsetCallback);
+      }
     } else {
-      auto* ptr = thread_scope_->Var(var->Name());
+      auto *ptr = thread_scope_->Var(var->Name());
       InitializeVariable(ptr, var->GetType());
     }
   }
 }
 
+template <typename T>
+void HogwildWorker::SetZero(LoDTensor *tensor, LoDTensor *root_tensor,
+                            int tensor_dim) {
+  T *ptr = tensor->mutable_data<T>(root_tensor->dims(), platform::CPUPlace());
+  memset(ptr, 0, sizeof(T) * tensor_dim);
+}
+
 void HogwildWorker::BindingDataFeedMemory() {
-  const std::vector<std::string>& input_feed =
+  const std::vector<std::string> &input_feed =
       device_reader_->GetUseSlotAlias();
   for (auto name : input_feed) {
     device_reader_->AddFeedVar(thread_scope_->FindVar(name), name);
   }
 }
 
-void HogwildWorker::CreateDeviceResource(const ProgramDesc& main_prog) {
+void HogwildWorker::CreateDeviceResource(const ProgramDesc &main_prog) {
   CreateThreadScope(main_prog);
   CreateThreadOperators(main_prog);
 }
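As I read the hunk above, its substantive change is the stat-var handling: for persistable variables listed in stat_var_name_map_ (shared statistics such as accumulated counters), every worker except thread 0 now gets its own zero-filled copy in its thread scope rather than aliasing the root-scope tensor, so concurrent Hogwild threads do not all accumulate into one buffer. The zero fill must be instantiated for the tensor's runtime element type, which is what the MemsetCallback / _ForEachDataType_ pair does; _ForEachDataType_ is the type-list macro from the newly included data_type.h. Below is a minimal standalone sketch of that dispatch pattern; MY_FOR_EACH_TYPE, ProtoType, and MiniTensor are hypothetical stand-ins, not Paddle API.

#include <cstdint>
#include <iostream>
#include <vector>

enum class ProtoType { FP32, INT64 };

// Hypothetical miniature tensor: a runtime type tag plus raw storage.
struct MiniTensor {
  ProtoType type;
  std::vector<unsigned char> buf;
};

// Zero-fill `dim` elements of element type T, like HogwildWorker::SetZero.
template <typename T>
void SetZero(MiniTensor *t, int dim) {
  t->buf.assign(sizeof(T) * dim, 0);
}

// Hypothetical type list standing in for _ForEachDataType_: expands the
// callback once per (C++ type, runtime tag) pair.
#define MY_FOR_EACH_TYPE(callback)  \
  callback(float, ProtoType::FP32); \
  callback(int64_t, ProtoType::INT64)

int main() {
  MiniTensor thread_tensor{ProtoType::INT64, {}};
  int tensor_dim = 8;
  // Same shape as MemsetCallback in the diff: compare the runtime tag and
  // instantiate the zeroing routine with the matching C++ type.
#define MemsetCallback(cpp_type, proto_type)         \
  do {                                               \
    if (thread_tensor.type == proto_type) {          \
      SetZero<cpp_type>(&thread_tensor, tensor_dim); \
    }                                                \
  } while (0)
  MY_FOR_EACH_TYPE(MemsetCallback);
#undef MemsetCallback
  std::cout << "zeroed " << thread_tensor.buf.size() << " bytes\n";  // 64
}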
@@ -78,7 +104,7 @@ void HogwildWorker::TrainFilesWithProfiler() {
   device_reader_->Start();
   std::vector<double> op_total_time;
   std::vector<std::string> op_name;
-  for (auto& op : ops_) {
+  for (auto &op : ops_) {
     op_name.push_back(op->Type());
   }
   op_total_time.resize(ops_.size());
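In the TrainFilesWithProfiler() hunk above, the worker snapshots every op's type into op_name and sizes op_total_time to match, one accumulator slot per operator. A minimal sketch of the accumulation those vectors plausibly feed, assuming hypothetical std::function stand-ins for the op list and steady_clock timing rather than Paddle's actual timer:

#include <chrono>
#include <cstdio>
#include <functional>
#include <string>
#include <vector>

int main() {
  std::vector<std::string> op_name = {"lookup_table", "mul", "sgd"};
  std::vector<std::function<void()>> ops(op_name.size(), [] {});
  std::vector<double> op_total_time(ops.size(), 0.0);  // seconds per op slot

  for (int batch = 0; batch < 100; ++batch) {
    for (size_t i = 0; i < ops.size(); ++i) {
      auto start = std::chrono::steady_clock::now();
      ops[i]();  // run one operator on the current batch
      std::chrono::duration<double> diff =
          std::chrono::steady_clock::now() - start;
      op_total_time[i] += diff.count();  // accumulate into the slot by index
    }
  }
  for (size_t i = 0; i < ops.size(); ++i) {
    std::printf("%s: %f s\n", op_name[i].c_str(), op_total_time[i]);
  }
}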
@@ -141,7 +167,7 @@ void HogwildWorker::TrainFiles() {
   device_reader_->Start();
   int cur_batch;
   while ((cur_batch = device_reader_->Next()) > 0) {
-    for (auto& op : ops_) {
+    for (auto &op : ops_) {
       bool need_skip = false;
       for (auto t = 0u; t < skip_ops_.size(); ++t) {
         if (op->Type().find(skip_ops_[t]) != std::string::npos) {
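The skip test in the TrainFiles() hunk uses std::string::find, so each entry in skip_ops_ matches as a substring of the op type, not as an exact name; an entry like "fetch" would also match "fetch_barrier". A self-contained sketch of just that predicate, with NeedSkip as a hypothetical helper name:

#include <iostream>
#include <string>
#include <vector>

// True when any configured skip entry occurs inside the op type string,
// the same test the diff performs inline.
bool NeedSkip(const std::string &op_type,
              const std::vector<std::string> &skip_ops) {
  for (auto t = 0u; t < skip_ops.size(); ++t) {
    if (op_type.find(skip_ops[t]) != std::string::npos) {
      return true;
    }
  }
  return false;
}

int main() {
  std::vector<std::string> skip_ops = {"feed", "fetch"};
  std::cout << NeedSkip("feed", skip_ops) << "\n";           // 1: exact
  std::cout << NeedSkip("fetch_barrier", skip_ops) << "\n";  // 1: substring
  std::cout << NeedSkip("mul_grad", skip_ops) << "\n";       // 0: no match
}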