update nce and hierarchical_sigmoid remote_prefetch

test=develop
revert-16555-model_data_cryption_link_all_lib
Qiao Longfei 6 years ago
parent a1821a0449
commit df45c8c538

@ -81,8 +81,9 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
nodes_to_delete.push_back(node);
VLOG(3) << "find and remove an recv op: "
<< recv_varname_to_ctx[recv_var_name];
} else if (node->Name() == "lookup_table") {
VLOG(0) << "set lookup_table op remote_prefetch to false";
} else if (node->Name() == "lookup_table" || node->Name() == "nce" ||
node->Name() == "hierarchical_sigmoid") {
VLOG(0) << "set " << node->Name() << " op remote_prefetch to false";
node->Op()->SetAttr("remote_prefetch", false);
}
}

@ -68,8 +68,9 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> {
size_t num_classes = static_cast<size_t>(ctx.Attr<int>("num_classes"));
// for remote prefetch
auto remote_prefetch = ctx.Attr<bool>("remote_prefetch");
auto epmap = ctx.Attr<std::vector<std::string>>("epmap");
if (!epmap.empty()) {
if (remote_prefetch && !epmap.empty()) {
// if epmap is not empty, then the parameter will be fetched from remote
// parameter
// server

@ -156,9 +156,10 @@ class NCEKernel : public framework::OpKernel<T> {
auto input_mat = EigenMatrix<T>::From(*(context.Input<Tensor>("Input")));
// for remote prefetch
auto remote_prefetch = context.Attr<bool>("remote_prefetch");
auto epmap = context.Attr<std::vector<std::string>>("epmap");
if (!epmap.empty()) {
if (remote_prefetch && !epmap.empty()) {
// if epmap is not empty, then the parameter will be fetched from remote
// parameter
// server

Loading…
Cancel
Save