|
|
|
@ -879,29 +879,36 @@ class TestRemoteNce(TestDistLookupTableBase):
|
|
|
|
|
class TestRemoteHsigmoid(TestDistLookupTableBase):
|
|
|
|
|
def network_with_table(self, is_sparse, is_distributed):
|
|
|
|
|
|
|
|
|
|
num_total_classes = 10
|
|
|
|
|
num_total_classes = 3
|
|
|
|
|
|
|
|
|
|
input = fluid.layers.data(name="input", shape=[10], dtype="float32")
|
|
|
|
|
input = fluid.layers.data(name="input", shape=[1], dtype="float32")
|
|
|
|
|
label = fluid.layers.data(name="label", shape=[1], dtype="int64")
|
|
|
|
|
path_table = fluid.layers.data(
|
|
|
|
|
name='path_table', shape=[10], dtype='int64')
|
|
|
|
|
name='path_table', shape=[3], dtype='int64')
|
|
|
|
|
path_code = fluid.layers.data(
|
|
|
|
|
name='path_code', shape=[10], dtype='int64')
|
|
|
|
|
name='path_code', shape=[3], dtype='int64')
|
|
|
|
|
w_param = fluid.default_main_program().global_block().create_parameter(
|
|
|
|
|
shape=[num_total_classes, 10],
|
|
|
|
|
dtype='float32',
|
|
|
|
|
name='hs_w',
|
|
|
|
|
initializer=fluid.initializer.ConstantInitializer())
|
|
|
|
|
b_param = fluid.default_main_program().global_block().create_parameter(
|
|
|
|
|
shape=[num_total_classes, 1],
|
|
|
|
|
shape=[3, 1],
|
|
|
|
|
dtype='float32',
|
|
|
|
|
name='hs_b',
|
|
|
|
|
initializer=fluid.initializer.ConstantInitializer())
|
|
|
|
|
|
|
|
|
|
cost = fluid.layers.hsigmoid(
|
|
|
|
|
emb = fluid.layers.embedding(
|
|
|
|
|
input=input,
|
|
|
|
|
is_sparse=is_sparse,
|
|
|
|
|
size=[3, 3],
|
|
|
|
|
param_attr=fluid.ParamAttr(initializer=fluid.initializer.Normal(
|
|
|
|
|
scale=1 / math.sqrt(num_total_classes))))
|
|
|
|
|
|
|
|
|
|
cost = fluid.layers.hsigmoid(
|
|
|
|
|
input=emb,
|
|
|
|
|
label=label,
|
|
|
|
|
num_classes=non_leaf_num,
|
|
|
|
|
num_classes=num_total_classes,
|
|
|
|
|
path_table=path_table,
|
|
|
|
|
path_code=path_code,
|
|
|
|
|
is_custom=True,
|
|
|
|
@ -918,9 +925,29 @@ class TestRemoteHsigmoid(TestDistLookupTableBase):
|
|
|
|
|
|
|
|
|
|
def transpiler_test_impl(self):
|
|
|
|
|
trainer, _ = self.get_trainer()
|
|
|
|
|
params_to_check = list()
|
|
|
|
|
for op in trainer.blocks[0].ops:
|
|
|
|
|
if op.type == "recv":
|
|
|
|
|
if op.type == "hierarchical_sigmoid":
|
|
|
|
|
params_to_check = [op.input("W")[0], op.input("Bias")[0]]
|
|
|
|
|
for name in ["epmap", "table_names", "epmap"]:
|
|
|
|
|
assert op.has_attr(name)
|
|
|
|
|
if name == "epmap":
|
|
|
|
|
assert op.attr(name)[0] == u'127.0.0.1:6174'
|
|
|
|
|
elif name == "table_names":
|
|
|
|
|
assert op.attr(name)[0] == u'hierarchical_sigmoid_0.w_0'
|
|
|
|
|
else:
|
|
|
|
|
assert op.attr(name) == 3
|
|
|
|
|
elif op.type == "lookup_table":
|
|
|
|
|
params_to_check.append(op.input("W")[0])
|
|
|
|
|
else:
|
|
|
|
|
pass
|
|
|
|
|
op_count = 0
|
|
|
|
|
for op in trainer.blocks[0].ops:
|
|
|
|
|
if op.type == "recv":
|
|
|
|
|
assert len(op.output("Out")) == 1
|
|
|
|
|
assert op.output("Out")[0] == u'hierarchical_sigmoid_0.b_0'
|
|
|
|
|
op_count += 1
|
|
|
|
|
assert op_count == 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
|