parent
6cbeafb6c0
commit
eeeef957c7
@ -0,0 +1,22 @@
|
|||||||
|
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#include "paddle/fluid/operators/distributed_ops/distributed_lookup_table_op.h"
|
||||||
|
|
||||||
|
namespace ops = paddle::operators;
|
||||||
|
namespace plat = paddle::platform;
|
||||||
|
|
||||||
|
REGISTER_OP_CUDA_KERNEL(
|
||||||
|
distributed_lookup_table,
|
||||||
|
ops::DistributedLookupTableKernel<plat::CUDADeviceContext, float>);
|
@ -0,0 +1,45 @@
|
|||||||
|
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
#include <algorithm>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include "paddle/fluid/framework/data_type.h"
|
||||||
|
#include "paddle/fluid/framework/op_registry.h"
|
||||||
|
#include "paddle/fluid/operators/distributed/parameter_prefetch.h"
|
||||||
|
#include "paddle/fluid/operators/math/math_function.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace operators {
|
||||||
|
|
||||||
|
template <typename DeviceContext, typename T>
|
||||||
|
class DistributedLookupTableKernel : public framework::OpKernel<T> {
|
||||||
|
public:
|
||||||
|
void Compute(const framework::ExecutionContext &context) const override {
|
||||||
|
auto ids_vars = context.MultiInputVar("Ids");
|
||||||
|
auto emb_vars = context.MultiOutput<framework::Tensor>("Embeddings");
|
||||||
|
|
||||||
|
auto id_names = context.InputNames("Ids");
|
||||||
|
auto embedding_name = context.InputNames("W").front();
|
||||||
|
auto out_names = context.OutputNames("Outputs");
|
||||||
|
auto lookup_tables = context.Attr<std::vector<std::string>>("table_names");
|
||||||
|
auto endpoints = context.Attr<std::vector<std::string>>("endpoints");
|
||||||
|
auto is_distributed = context.Attr<bool>("is_distributed");
|
||||||
|
|
||||||
|
operators::distributed::prefetchs(id_names, out_names, embedding_name,
|
||||||
|
is_distributed, lookup_tables, endpoints,
|
||||||
|
context, context.scope());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace operators
|
||||||
|
} // namespace paddle
|
@ -0,0 +1,152 @@
|
|||||||
|
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
Distribute CTR model for test fleet api
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import shutil
|
||||||
|
import tempfile
|
||||||
|
import time
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
import paddle.fluid as fluid
|
||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import ctr_dataset_reader
|
||||||
|
from test_dist_fleet_base import runtime_main, FleetDistRunnerBase
|
||||||
|
from dist_fleet_ctr import TestDistCTR2x2, fake_ctr_reader
|
||||||
|
from paddle.distributed.fleet.base.util_factory import fleet_util
|
||||||
|
|
||||||
|
# Fix seed for test
|
||||||
|
fluid.default_startup_program().random_seed = 1
|
||||||
|
fluid.default_main_program().random_seed = 1
|
||||||
|
|
||||||
|
|
||||||
|
class TestDistGpuPsCTR2x2(TestDistCTR2x2):
|
||||||
|
"""
|
||||||
|
For test CTR model, using Fleet api & PS-GPU
|
||||||
|
"""
|
||||||
|
|
||||||
|
def check_model_right(self, dirname):
|
||||||
|
model_filename = os.path.join(dirname, "__model__")
|
||||||
|
|
||||||
|
with open(model_filename, "rb") as f:
|
||||||
|
program_desc_str = f.read()
|
||||||
|
|
||||||
|
program = fluid.Program.parse_from_string(program_desc_str)
|
||||||
|
with open(os.path.join(dirname, "__model__.proto"), "w") as wn:
|
||||||
|
wn.write(str(program))
|
||||||
|
|
||||||
|
def do_pyreader_training(self, fleet):
|
||||||
|
"""
|
||||||
|
do training using dataset, using fetch handler to catch variable
|
||||||
|
Args:
|
||||||
|
fleet(Fleet api): the fleet object of Parameter Server, define distribute training role
|
||||||
|
"""
|
||||||
|
device_id = int(os.getenv("FLAGS_selected_gpus", "0"))
|
||||||
|
place = fluid.CUDAPlace(device_id)
|
||||||
|
exe = fluid.Executor(place)
|
||||||
|
fleet.init_worker()
|
||||||
|
exe.run(fleet.startup_program)
|
||||||
|
|
||||||
|
batch_size = 4
|
||||||
|
train_reader = paddle.batch(fake_ctr_reader(), batch_size=batch_size)
|
||||||
|
self.reader.decorate_sample_list_generator(train_reader)
|
||||||
|
|
||||||
|
for epoch_id in range(1):
|
||||||
|
self.reader.start()
|
||||||
|
try:
|
||||||
|
pass_start = time.time()
|
||||||
|
while True:
|
||||||
|
loss_val = exe.run(program=fleet.main_program,
|
||||||
|
fetch_list=[self.avg_cost.name])
|
||||||
|
loss_val = np.mean(loss_val)
|
||||||
|
reduce_output = fleet_util.all_reduce(
|
||||||
|
np.array(loss_val), mode="sum")
|
||||||
|
loss_all_trainer = fleet_util.all_gather(float(loss_val))
|
||||||
|
loss_val = float(reduce_output) / len(loss_all_trainer)
|
||||||
|
message = "TRAIN ---> pass: {} loss: {}\n".format(epoch_id,
|
||||||
|
loss_val)
|
||||||
|
fleet_util.print_on_rank(message, 0)
|
||||||
|
|
||||||
|
pass_time = time.time() - pass_start
|
||||||
|
except fluid.core.EOFException:
|
||||||
|
self.reader.reset()
|
||||||
|
|
||||||
|
model_dir = tempfile.mkdtemp()
|
||||||
|
fleet.save_inference_model(
|
||||||
|
exe, model_dir, [feed.name for feed in self.feeds], self.avg_cost)
|
||||||
|
self.check_model_right(model_dir)
|
||||||
|
if fleet.is_first_worker():
|
||||||
|
fleet.save_persistables(executor=exe, dirname=model_dir)
|
||||||
|
shutil.rmtree(model_dir)
|
||||||
|
fleet.stop_worker()
|
||||||
|
|
||||||
|
def do_dataset_training(self, fleet):
|
||||||
|
dnn_input_dim, lr_input_dim, train_file_path = ctr_dataset_reader.prepare_data(
|
||||||
|
)
|
||||||
|
|
||||||
|
device_id = int(os.getenv("FLAGS_selected_gpus", "0"))
|
||||||
|
place = fluid.CUDAPlace(device_id)
|
||||||
|
exe = fluid.Executor(place)
|
||||||
|
|
||||||
|
fleet.init_worker()
|
||||||
|
exe.run(fleet.startup_program)
|
||||||
|
|
||||||
|
thread_num = 2
|
||||||
|
batch_size = 128
|
||||||
|
filelist = []
|
||||||
|
for _ in range(thread_num):
|
||||||
|
filelist.append(train_file_path)
|
||||||
|
|
||||||
|
# config dataset
|
||||||
|
dataset = paddle.fleet.DatasetFactory().create_dataset()
|
||||||
|
dataset.set_batch_size(batch_size)
|
||||||
|
dataset.set_use_var(self.feeds)
|
||||||
|
pipe_command = 'python ctr_dataset_reader.py'
|
||||||
|
dataset.set_pipe_command(pipe_command)
|
||||||
|
|
||||||
|
dataset.set_filelist(filelist)
|
||||||
|
dataset.set_thread(thread_num)
|
||||||
|
|
||||||
|
for epoch_id in range(1):
|
||||||
|
pass_start = time.time()
|
||||||
|
dataset.set_filelist(filelist)
|
||||||
|
exe.train_from_dataset(
|
||||||
|
program=fleet.main_program,
|
||||||
|
dataset=dataset,
|
||||||
|
fetch_list=[self.avg_cost],
|
||||||
|
fetch_info=["cost"],
|
||||||
|
print_period=2,
|
||||||
|
debug=int(os.getenv("Debug", "0")))
|
||||||
|
pass_time = time.time() - pass_start
|
||||||
|
|
||||||
|
if os.getenv("SAVE_MODEL") == "1":
|
||||||
|
model_dir = tempfile.mkdtemp()
|
||||||
|
fleet.save_inference_model(exe, model_dir,
|
||||||
|
[feed.name for feed in self.feeds],
|
||||||
|
self.avg_cost)
|
||||||
|
self.check_model_right(model_dir)
|
||||||
|
if fleet.is_first_worker():
|
||||||
|
fleet.save_persistables(executor=exe, dirname=model_dir)
|
||||||
|
shutil.rmtree(model_dir)
|
||||||
|
|
||||||
|
fleet.stop_worker()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
runtime_main(TestDistGpuPsCTR2x2)
|
Loading…
Reference in new issue