@@ -8,20 +8,22 @@ class TestCrossEntropy(OpTest):
         self.op_type = "onehot_cross_entropy"
         batch_size = 30
         class_num = 10
         X = numpy.random.uniform(0.1, 1.0,
                                  [batch_size, class_num]).astype("float32")
-        label = (class_num / 2) * numpy.ones(batch_size).astype("int32")
-        self.inputs = {'X': X, 'label': label}
-        Y = []
-        for i in range(0, batch_size):
-            Y.append(-numpy.log(X[i][label[i]]))
-        self.outputs = {'Y': numpy.array(Y).astype("float32")}
+        labels = numpy.random.randint(0, class_num, batch_size, dtype="int32")
+        cross_entropy = numpy.asmatrix(
+            [[-numpy.log(X[i][labels[i]])] for i in range(X.shape[0])],
+            dtype="float32")
+        self.inputs = {"X": X, "label": labels}
+        self.outputs = {"Y": cross_entropy}
 
     def test_check_output(self):
         self.check_output()
 
     def test_check_grad(self):
-        self.check_grad(['X'], 'Y')
+        self.check_grad(["X"], "Y")
 
 
 if __name__ == "__main__":
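
For anyone reproducing the new reference values outside the test harness, the per-row list comprehension in the patch has a vectorized numpy equivalent; a minimal standalone sketch (not part of the patch, variable names reused from the test for clarity):

    import numpy

    batch_size, class_num = 30, 10
    # Random per-example class probabilities in (0.1, 1.0), as in setUp().
    X = numpy.random.uniform(0.1, 1.0,
                             [batch_size, class_num]).astype("float32")
    # One random ground-truth class index per example.
    labels = numpy.random.randint(0, class_num, batch_size, dtype="int32")

    # Same values as the patched reference: for each row i, pick the
    # probability at column labels[i], negate its log, and keep a
    # (batch_size, 1) column shape to match the op's output "Y".
    cross_entropy = -numpy.log(
        X[numpy.arange(batch_size), labels]).reshape(batch_size, 1)
    cross_entropy = cross_entropy.astype("float32")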