# 81 lines
# 2.3 KiB

# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
'''
Test mixed layer, projections and operators.
'''
from paddle.trainer_config_helpers import *
# Trainer-wide settings for this config.
settings(batch_size=1000, learning_rate=1e-4)

# Input: a data layer of size 100 fed through an embedding of size 256.
# `din` is deliberately rebound so later layers consume the embedding output.
din = data_layer(name='test', size=100)
din = embedding_layer(input=din, size=256)

# A chain of mixed layers, each holding exactly one projection and each
# consuming the previous layer in the chain.
with mixed_layer(size=100) as m1:
    m1 += full_matrix_projection(input=din)

with mixed_layer(size=100) as m2:
    m2 += table_projection(input=m1)

with mixed_layer(size=100) as m3:
    m3 += identity_projection(input=m2)

with mixed_layer(size=100) as m4:
    m4 += dotmul_projection(input=m3)

# From here on no explicit `size` is passed to mixed_layer.
with mixed_layer() as m5:
    m5 += context_projection(input=m4, context_len=3)

# One mixed layer that combines an operator with a projection.
with mixed_layer() as m6:
    m6 += dotmul_operator(a=m3, b=m4)
    m6 += scaling_projection(m3)

# Inputs for the convolution cases.
# NOTE(review): the filter size 3 * 3 * 1 * 64 presumably corresponds to
# filter_size * filter_size * num_channels * num_filters as used below --
# confirm against the conv_operator documentation.
img = data_layer(name='img', size=32 * 32)
flt = data_layer(name='filter', size=3 * 3 * 1 * 64)

# conv_operator takes its filter from a layer (`flt`); conv_projection is
# only given the image input.
with mixed_layer() as m7:
    m7 += conv_operator(
        img=img, filter=flt, num_filters=64, num_channels=1, filter_size=3)
    m7 += conv_projection(img, filter_size=3, num_filters=64, num_channels=1)

# Same pair as m7, but with stride=2, padding=1 and trans=True.
# NOTE(review): trans=True presumably selects the transposed convolution --
# confirm against the layer documentation.
with mixed_layer() as m8:
    m8 += conv_operator(
        img=img,
        filter=flt,
        num_filters=64,
        num_channels=1,
        filter_size=3,
        stride=2,
        padding=1,
        trans=True)
    m8 += conv_projection(
        img,
        filter_size=3,
        num_filters=64,
        num_channels=1,
        stride=2,
        padding=1,
        trans=True)

# Final mixed layer built through the `input=` list form instead of `with`.
# It merges m5..m8; m6 goes through trans_full_matrix_projection, the others
# through full_matrix_projection.
end = mixed_layer(
    input=[
        full_matrix_projection(input=m5),
        trans_full_matrix_projection(input=m6),
        full_matrix_projection(input=m7),
        full_matrix_projection(input=m8),
    ],
    size=100,
    layer_attr=ExtraAttr(
        drop_rate=0.5, error_clipping_threshold=40))

outputs(end)