Fix function in fit-a-line with new API (#11020)

wangkuiyi-patch-1
Siddharth Goyal 7 years ago committed by GitHub
parent fae3d8d2dc
commit 52e2eb65b9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -38,7 +38,7 @@ def inference_program():
return y_predict
def linear():
def train_program():
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
y_predict = inference_program()
@ -104,7 +104,7 @@ def main(use_cuda):
# Directory for saving the trained model
params_dirname = "fit_a_line.inference.model"
train(use_cuda, linear, params_dirname)
train(use_cuda, train_program, params_dirname)
infer(use_cuda, inference_program, params_dirname)

Loading…
Cancel
Save