test=develop, fix hsigmoid dereference nullptr (#16769)

* test=develop, fix hsigmoid dereference nullptr

* test=develop, refine condition

* test=develop, refine comments
revert-16839-cmakelist_change
Jiabin Yang 6 years ago committed by liuwei1031
parent 19bb53fa61
commit 84b7a7291e

@ -238,6 +238,8 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
zero(dev_ctx, w_grad, static_cast<T>(0.0)); zero(dev_ctx, w_grad, static_cast<T>(0.0));
bit_code->MulGradWeight(pre_out_grad, w_grad, in); bit_code->MulGradWeight(pre_out_grad, w_grad, in);
} else { } else {
PADDLE_ENFORCE(path != nullptr,
"Sparse mode should not be used without custom tree!");
framework::Vector<int64_t> real_rows = PathToRows(*path); framework::Vector<int64_t> real_rows = PathToRows(*path);
auto* w_grad = auto* w_grad =
ctx.Output<framework::SelectedRows>(framework::GradVarName("W")); ctx.Output<framework::SelectedRows>(framework::GradVarName("W"));

@ -5721,12 +5721,21 @@ def hsigmoid(input,
raise ValueError( raise ValueError(
"num_classes must not be less than 2 with default tree") "num_classes must not be less than 2 with default tree")
if (not is_custom) and (is_sparse):
print("Sparse mode should not be used without custom tree")
is_sparse = False
if (not is_custom) and ((path_table is not None) or
(path_code is not None)):
raise ValueError(
"only num_classes should be passed without custom tree")
if (is_custom) and (path_code is None): if (is_custom) and (path_code is None):
raise ValueError("path_code should not be None with costum tree") raise ValueError("path_code should not be None with custom tree")
elif (is_custom) and (path_table is None): elif (is_custom) and (path_table is None):
raise ValueError("path_table should not be None with costum tree") raise ValueError("path_table should not be None with custom tree")
elif (is_custom) and (num_classes is None): elif (is_custom) and (num_classes is None):
raise ValueError("num_classes should not be None with costum tree") raise ValueError("num_classes should not be None with custom tree")
else: else:
pass pass

Loading…
Cancel
Save