parent
6cbeafb6c0
commit
eeeef957c7
@ -0,0 +1,22 @@
|
||||
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/operators/distributed_ops/distributed_lookup_table_op.h"
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
namespace plat = paddle::platform;
|
||||
|
||||
REGISTER_OP_CUDA_KERNEL(
|
||||
distributed_lookup_table,
|
||||
ops::DistributedLookupTableKernel<plat::CUDADeviceContext, float>);
|
@ -0,0 +1,45 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
#include <algorithm>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "paddle/fluid/framework/data_type.h"
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
#include "paddle/fluid/operators/distributed/parameter_prefetch.h"
|
||||
#include "paddle/fluid/operators/math/math_function.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
template <typename DeviceContext, typename T>
|
||||
class DistributedLookupTableKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext &context) const override {
|
||||
auto ids_vars = context.MultiInputVar("Ids");
|
||||
auto emb_vars = context.MultiOutput<framework::Tensor>("Embeddings");
|
||||
|
||||
auto id_names = context.InputNames("Ids");
|
||||
auto embedding_name = context.InputNames("W").front();
|
||||
auto out_names = context.OutputNames("Outputs");
|
||||
auto lookup_tables = context.Attr<std::vector<std::string>>("table_names");
|
||||
auto endpoints = context.Attr<std::vector<std::string>>("endpoints");
|
||||
auto is_distributed = context.Attr<bool>("is_distributed");
|
||||
|
||||
operators::distributed::prefetchs(id_names, out_names, embedding_name,
|
||||
is_distributed, lookup_tables, endpoints,
|
||||
context, context.scope());
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,152 @@
|
||||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Distribute CTR model for test fleet api
|
||||
"""
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import shutil
|
||||
import tempfile
|
||||
import time
|
||||
|
||||
import paddle
|
||||
import paddle.fluid as fluid
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
import ctr_dataset_reader
|
||||
from test_dist_fleet_base import runtime_main, FleetDistRunnerBase
|
||||
from dist_fleet_ctr import TestDistCTR2x2, fake_ctr_reader
|
||||
from paddle.distributed.fleet.base.util_factory import fleet_util
|
||||
|
||||
# Fix seed for test
|
||||
fluid.default_startup_program().random_seed = 1
|
||||
fluid.default_main_program().random_seed = 1
|
||||
|
||||
|
||||
class TestDistGpuPsCTR2x2(TestDistCTR2x2):
|
||||
"""
|
||||
For test CTR model, using Fleet api & PS-GPU
|
||||
"""
|
||||
|
||||
def check_model_right(self, dirname):
|
||||
model_filename = os.path.join(dirname, "__model__")
|
||||
|
||||
with open(model_filename, "rb") as f:
|
||||
program_desc_str = f.read()
|
||||
|
||||
program = fluid.Program.parse_from_string(program_desc_str)
|
||||
with open(os.path.join(dirname, "__model__.proto"), "w") as wn:
|
||||
wn.write(str(program))
|
||||
|
||||
def do_pyreader_training(self, fleet):
|
||||
"""
|
||||
do training using dataset, using fetch handler to catch variable
|
||||
Args:
|
||||
fleet(Fleet api): the fleet object of Parameter Server, define distribute training role
|
||||
"""
|
||||
device_id = int(os.getenv("FLAGS_selected_gpus", "0"))
|
||||
place = fluid.CUDAPlace(device_id)
|
||||
exe = fluid.Executor(place)
|
||||
fleet.init_worker()
|
||||
exe.run(fleet.startup_program)
|
||||
|
||||
batch_size = 4
|
||||
train_reader = paddle.batch(fake_ctr_reader(), batch_size=batch_size)
|
||||
self.reader.decorate_sample_list_generator(train_reader)
|
||||
|
||||
for epoch_id in range(1):
|
||||
self.reader.start()
|
||||
try:
|
||||
pass_start = time.time()
|
||||
while True:
|
||||
loss_val = exe.run(program=fleet.main_program,
|
||||
fetch_list=[self.avg_cost.name])
|
||||
loss_val = np.mean(loss_val)
|
||||
reduce_output = fleet_util.all_reduce(
|
||||
np.array(loss_val), mode="sum")
|
||||
loss_all_trainer = fleet_util.all_gather(float(loss_val))
|
||||
loss_val = float(reduce_output) / len(loss_all_trainer)
|
||||
message = "TRAIN ---> pass: {} loss: {}\n".format(epoch_id,
|
||||
loss_val)
|
||||
fleet_util.print_on_rank(message, 0)
|
||||
|
||||
pass_time = time.time() - pass_start
|
||||
except fluid.core.EOFException:
|
||||
self.reader.reset()
|
||||
|
||||
model_dir = tempfile.mkdtemp()
|
||||
fleet.save_inference_model(
|
||||
exe, model_dir, [feed.name for feed in self.feeds], self.avg_cost)
|
||||
self.check_model_right(model_dir)
|
||||
if fleet.is_first_worker():
|
||||
fleet.save_persistables(executor=exe, dirname=model_dir)
|
||||
shutil.rmtree(model_dir)
|
||||
fleet.stop_worker()
|
||||
|
||||
def do_dataset_training(self, fleet):
|
||||
dnn_input_dim, lr_input_dim, train_file_path = ctr_dataset_reader.prepare_data(
|
||||
)
|
||||
|
||||
device_id = int(os.getenv("FLAGS_selected_gpus", "0"))
|
||||
place = fluid.CUDAPlace(device_id)
|
||||
exe = fluid.Executor(place)
|
||||
|
||||
fleet.init_worker()
|
||||
exe.run(fleet.startup_program)
|
||||
|
||||
thread_num = 2
|
||||
batch_size = 128
|
||||
filelist = []
|
||||
for _ in range(thread_num):
|
||||
filelist.append(train_file_path)
|
||||
|
||||
# config dataset
|
||||
dataset = paddle.fleet.DatasetFactory().create_dataset()
|
||||
dataset.set_batch_size(batch_size)
|
||||
dataset.set_use_var(self.feeds)
|
||||
pipe_command = 'python ctr_dataset_reader.py'
|
||||
dataset.set_pipe_command(pipe_command)
|
||||
|
||||
dataset.set_filelist(filelist)
|
||||
dataset.set_thread(thread_num)
|
||||
|
||||
for epoch_id in range(1):
|
||||
pass_start = time.time()
|
||||
dataset.set_filelist(filelist)
|
||||
exe.train_from_dataset(
|
||||
program=fleet.main_program,
|
||||
dataset=dataset,
|
||||
fetch_list=[self.avg_cost],
|
||||
fetch_info=["cost"],
|
||||
print_period=2,
|
||||
debug=int(os.getenv("Debug", "0")))
|
||||
pass_time = time.time() - pass_start
|
||||
|
||||
if os.getenv("SAVE_MODEL") == "1":
|
||||
model_dir = tempfile.mkdtemp()
|
||||
fleet.save_inference_model(exe, model_dir,
|
||||
[feed.name for feed in self.feeds],
|
||||
self.avg_cost)
|
||||
self.check_model_right(model_dir)
|
||||
if fleet.is_first_worker():
|
||||
fleet.save_persistables(executor=exe, dirname=model_dir)
|
||||
shutil.rmtree(model_dir)
|
||||
|
||||
fleet.stop_worker()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
runtime_main(TestDistGpuPsCTR2x2)
|
Loading…
Reference in new issue