|
|
|
|
@ -230,14 +230,14 @@ class NCEOpGradVarTypeInference : public framework::VarTypeInference {
|
|
|
|
|
auto attr = op_desc.GetAttr("is_sparse");
|
|
|
|
|
bool is_sparse = boost::get<bool>(attr);
|
|
|
|
|
if (is_sparse) {
|
|
|
|
|
VLOG(30) << "nce_op_grad op " << weight_grad << " and " << bias_grad
|
|
|
|
|
<< " is set to SelectedRows";
|
|
|
|
|
VLOG(3) << "nce_op_grad op " << weight_grad << " and " << bias_grad
|
|
|
|
|
<< " is set to SelectedRows";
|
|
|
|
|
block->Var(weight_grad)
|
|
|
|
|
->SetType(framework::proto::VarType::SELECTED_ROWS);
|
|
|
|
|
block->Var(bias_grad)->SetType(framework::proto::VarType::SELECTED_ROWS);
|
|
|
|
|
} else {
|
|
|
|
|
VLOG(30) << "nce_op_grad op " << weight_grad << " and " << bias_grad
|
|
|
|
|
<< " is set to LoDTensor";
|
|
|
|
|
VLOG(3) << "nce_op_grad op " << weight_grad << " and " << bias_grad
|
|
|
|
|
<< " is set to LoDTensor";
|
|
|
|
|
block->Var(weight_grad)->SetType(framework::proto::VarType::LOD_TENSOR);
|
|
|
|
|
block->Var(bias_grad)->SetType(framework::proto::VarType::LOD_TENSOR);
|
|
|
|
|
}
|
|
|
|
|
|