Fix jit.to_static usage (#26682)

Aurelius84 5 years ago committed by GitHub
parent cb00d50498
commit 67d03bed70

@@ -38,7 +38,7 @@ from paddle.fluid.dygraph.dygraph_to_static.utils import get_attribute_full_name
 __all__ = ['DygraphToStaticAst']

-DECORATOR_NAMES = ['declarative', 'dygraph_to_static_func']
+DECORATOR_NAMES = ['declarative', 'to_static', 'dygraph_to_static_func']

 class DygraphToStaticAst(gast.NodeTransformer):
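
For context, a hedged sketch (not part of this diff) of what the new 'to_static' entry enables: the AST transformer skips decorators whose names appear in DECORATOR_NAMES, so both spellings below should now be handled. The helper functions here are illustrative only.

import paddle
from paddle.jit import to_static

# Plain-name form, matched by the new 'to_static' entry:
@to_static
def double(x):
    return x * 2

# Attribute form; its full name is presumably resolved via
# get_attribute_full_name (imported in the hunk above):
@paddle.jit.to_static
def triple(x):
    return x * 3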

@@ -17,12 +17,13 @@ from __future__ import print_function
 import numpy
 import unittest

+import paddle
 import paddle.fluid as fluid
 from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator
 from paddle.fluid.dygraph.jit import declarative


-@declarative
+@paddle.jit.to_static
 def dyfunc_assert_variable(x):
     x_v = fluid.dygraph.to_variable(x)
     assert x_v

@@ -15,11 +15,11 @@
 import math
 import numpy as np
 import unittest
+from paddle.jit import to_static
 import paddle.fluid as fluid
 from paddle.fluid import ParamAttr
 from paddle.fluid.dygraph import to_variable
-from paddle.fluid.dygraph import declarative, ProgramTranslator
+from paddle.fluid.dygraph import ProgramTranslator
 from paddle.fluid.dygraph.io import VARIABLE_FILENAME

 from predictor_utils import PredictorTools
@@ -242,7 +242,7 @@ class BMN(fluid.dygraph.Layer):
             param_attr=ParamAttr(name="PEM_2d4_w"),
             bias_attr=ParamAttr(name="PEM_2d4_b"))

-    @declarative
+    @to_static
     def forward(self, x):
         # Base Module
         x = self.b_conv1(x)

@@ -19,7 +19,7 @@ import numpy as np
 import unittest
 import paddle.fluid as fluid
-from paddle.fluid.dygraph.jit import declarative
+from paddle.jit import to_static
 from paddle.fluid.dygraph.dygraph_to_static.program_translator import ProgramTranslator

 PLACE = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda() else fluid.CPUPlace(
@@ -76,7 +76,7 @@ class MainNetWithDict(fluid.dygraph.Layer):
         self.output_size = output_size
         self.sub_net = SubNetWithDict(hidden_size, output_size)

-    @declarative
+    @to_static
     def forward(self, input, max_len=4):
         input = fluid.dygraph.to_variable(input)
         cache = {
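
The same swap recurs for Layer methods throughout these tests. Below is a hedged, self-contained sketch of the pattern (SimpleNet and its sizes are made up for illustration): decorate forward with to_static and call the layer as usual under a dygraph guard.

import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph.nn import Linear
from paddle.jit import to_static

class SimpleNet(fluid.dygraph.Layer):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc = Linear(4, 2)

    @to_static  # forward calls are traced into a static graph when enabled
    def forward(self, x):
        return self.fc(x)

with fluid.dygraph.guard():
    net = SimpleNet()
    x = fluid.dygraph.to_variable(np.ones([3, 4], dtype='float32'))
    print(net(x).shape)  # [3, 2]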

@@ -25,7 +25,6 @@ from paddle.fluid.dygraph.base import switch_to_static_graph
 from paddle.fluid.dygraph import to_variable
 from paddle.fluid.dygraph.nn import Conv2D, Linear, Pool2D
 from paddle.fluid.optimizer import AdamOptimizer
-from paddle.fluid.dygraph.jit import declarative
 from paddle.fluid.dygraph.io import VARIABLE_FILENAME
 from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator
@@ -102,7 +101,7 @@ class MNIST(fluid.dygraph.Layer):
                 loc=0.0, scale=scale)),
             act="softmax")

-    @declarative
+    @paddle.jit.to_static
     def forward(self, inputs, label=None):
         x = self.inference(inputs)
         if label is not None:
@@ -134,7 +133,7 @@ class TestMNIST(unittest.TestCase):
             drop_last=True)

-class TestMNISTWithDeclarative(TestMNIST):
+class TestMNISTWithToStatic(TestMNIST):
     """
     Tests model if doesn't change the layers while decorated
     by `dygraph_to_static_output`. In this case, everything should
@@ -147,7 +146,7 @@ class TestMNISTWithDeclarative(TestMNIST):
     def train_dygraph(self):
         return self.train(to_static=False)

-    def test_mnist_declarative(self):
+    def test_mnist_to_static(self):
         dygraph_loss = self.train_dygraph()
         static_loss = self.train_static()
         self.assertTrue(
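
These tests compare a dygraph loss against a static one by toggling translation. A hedged sketch of that toggle, assuming self.train(to_static=...) flips it through ProgramTranslator as the imports in this commit suggest:

from paddle.fluid.dygraph import ProgramTranslator

prog_trans = ProgramTranslator()  # singleton in this Paddle version

prog_trans.enable(False)  # `to_static`-decorated code runs in pure dygraph
# ... run the model, collect dygraph_loss ...

prog_trans.enable(True)   # re-enable dygraph-to-static transformation
# ... run the model again, collect static_loss ...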
