parent
e206e712f3
commit
a941422337
@ -0,0 +1,67 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Evaluation for DQN"""
|
||||
|
||||
import argparse
|
||||
import gym
|
||||
from mindspore import context
|
||||
from mindspore.common import set_seed
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from src.config import config_dqn as cfg
|
||||
from src.agent import Agent
|
||||
|
||||
parser = argparse.ArgumentParser(description='MindSpore dqn Example')
|
||||
parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU'],
|
||||
help='device where the code will be implemented (default: Ascend)')
|
||||
parser.add_argument('--ckpt_path', type=str, default=None, help='if is test, must provide\
|
||||
path where the trained ckpt file')
|
||||
args = parser.parse_args()
|
||||
set_seed(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||
env = gym.make('CartPole-v0')
|
||||
cfg.state_space_dim = env.observation_space.shape[0]
|
||||
cfg.action_space_dim = env.action_space.n
|
||||
agent = Agent(**cfg)
|
||||
|
||||
# load checkpoint
|
||||
if args.ckpt_path:
|
||||
param_dict = load_checkpoint(args.ckpt_path)
|
||||
not_load_param = load_param_into_net(agent.policy_net, param_dict)
|
||||
if not_load_param:
|
||||
raise ValueError("Load param into net fail!")
|
||||
|
||||
score = 0
|
||||
agent.load_dict()
|
||||
for episode in range(50):
|
||||
s0 = env.reset()
|
||||
total_reward = 1
|
||||
while True:
|
||||
a0 = agent.eval_act(s0)
|
||||
s1, r1, done, _ = env.step(a0)
|
||||
|
||||
if done:
|
||||
r1 = -1
|
||||
|
||||
if done:
|
||||
break
|
||||
|
||||
total_reward += r1
|
||||
s0 = s1
|
||||
score += total_reward
|
||||
print("episode", episode, "total_reward", total_reward)
|
||||
print("mean_reward", score/50)
|
@ -0,0 +1 @@
|
||||
gym
|
@ -0,0 +1,21 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
# an simple tutorial as follows, more parameters can be setting
|
||||
script_self=$(readlink -f "$0")
|
||||
self_path=$(dirname "${script_self}")
|
||||
CKPT_PATH=$1
|
||||
python -s ${self_path}/../eval.py --device_target="Ascend" --ckpt_path=$CKPT_PATH > log.txt 2>&1 &
|
@ -0,0 +1,21 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
# an simple tutorial as follows, more parameters can be setting
|
||||
script_self=$(readlink -f "$0")
|
||||
self_path=$(dirname "${script_self}")
|
||||
CKPT_PATH=$1
|
||||
python -s ${self_path}/../eval.py --device_target="GPU" --ckpt_path=$CKPT_PATH > log.txt 2>&1 &
|
@ -0,0 +1,21 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
# an simple tutorial as follows, more parameters can be setting
|
||||
script_self=$(readlink -f "$0")
|
||||
self_path=$(dirname "${script_self}")
|
||||
CKPT_PATH=$1
|
||||
python -s ${self_path}/../train.py --device_target="Ascend" --ckpt_path=$CKPT_PATH > log.txt 2>&1 &
|
@ -0,0 +1,21 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
# an simple tutorial as follows, more parameters can be setting
|
||||
script_self=$(readlink -f "$0")
|
||||
self_path=$(dirname "${script_self}")
|
||||
CKPT_PATH=$1
|
||||
python -s ${self_path}/../train.py --device_target="GPU" --ckpt_path=$CKPT_PATH > log.txt 2>&1 &
|
@ -0,0 +1,94 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Agent of reinforcement learning network"""
|
||||
|
||||
import random
|
||||
import math
|
||||
import numpy as np
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
import mindspore.common.dtype as mstype
|
||||
from src.dqn import DQN, WithLossCell
|
||||
|
||||
|
||||
class Agent:
|
||||
"""
|
||||
DQN Agent
|
||||
"""
|
||||
def __init__(self, **kwargs):
|
||||
for key, value in kwargs.items():
|
||||
setattr(self, key, value)
|
||||
self.policy_net = DQN(self.state_space_dim, 256, self.action_space_dim)
|
||||
self.target_net = DQN(self.state_space_dim, 256, self.action_space_dim)
|
||||
self.optimizer = nn.RMSProp(self.policy_net.trainable_params(), learning_rate=self.lr)
|
||||
loss_fn = nn.MSELoss()
|
||||
loss_q_net = WithLossCell(self.policy_net, loss_fn)
|
||||
self.policy_net_train = nn.TrainOneStepCell(loss_q_net, self.optimizer)
|
||||
self.policy_net_train.set_train(mode=True)
|
||||
self.buffer = []
|
||||
self.steps = 0
|
||||
|
||||
def act(self, s0):
|
||||
"""
|
||||
Agent choose action.
|
||||
"""
|
||||
self.steps += 1
|
||||
epsi = self.epsi_low + (self.epsi_high - self.epsi_low) * (math.exp(-1.0 * self.steps / self.decay))
|
||||
if random.random() < epsi:
|
||||
a0 = random.randrange(self.action_space_dim)
|
||||
else:
|
||||
s0 = np.expand_dims(s0, axis=0)
|
||||
s0 = Tensor(s0, mstype.float32)
|
||||
a0 = self.policy_net(s0).asnumpy()
|
||||
a0 = np.argmax(a0)
|
||||
return a0
|
||||
|
||||
def eval_act(self, s0):
|
||||
self.steps += 1
|
||||
s0 = np.expand_dims(s0, axis=0)
|
||||
s0 = Tensor(s0, mstype.float32)
|
||||
a0 = self.policy_net(s0).asnumpy()
|
||||
a0 = np.argmax(a0)
|
||||
return a0
|
||||
|
||||
def put(self, *transition):
|
||||
if len(self.buffer) == self.capacity:
|
||||
self.buffer.pop(0)
|
||||
self.buffer.append(transition)
|
||||
|
||||
def load_dict(self):
|
||||
for target_item, source_item in zip(self.target_net.parameters_dict(), self.policy_net.parameters_dict()):
|
||||
target_param = self.target_net.parameters_dict()[target_item]
|
||||
source_param = self.policy_net.parameters_dict()[source_item]
|
||||
target_param.set_data(source_param.data)
|
||||
|
||||
def learn(self):
|
||||
"""
|
||||
Agent learn from experience data.
|
||||
"""
|
||||
if (len(self.buffer)) < self.batch_size:
|
||||
return
|
||||
|
||||
samples = random.sample(self.buffer, self.batch_size)
|
||||
s0, a0, r1, s1 = zip(*samples)
|
||||
s1 = Tensor(s1, mstype.float32)
|
||||
s0 = Tensor(s0, mstype.float32)
|
||||
a0 = Tensor(np.expand_dims(a0, axis=1))
|
||||
next_state_values = self.target_net(s1).asnumpy()
|
||||
next_state_values = np.max(next_state_values, axis=1)
|
||||
|
||||
y_true = r1 + self.gamma * next_state_values
|
||||
y_true = Tensor(np.expand_dims(y_true, axis=1), mstype.float32)
|
||||
self.policy_net_train(s0, a0, y_true)
|
@ -0,0 +1,31 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
network config setting, will be used in train.py and eval.py
|
||||
"""
|
||||
|
||||
from easydict import EasyDict as edict
|
||||
|
||||
config_dqn = edict({
|
||||
'gamma': 0.8,
|
||||
'epsi_high': 0.9,
|
||||
'epsi_low': 0.05,
|
||||
'decay': 200,
|
||||
'lr': 0.001,
|
||||
'capacity': 100000,
|
||||
'batch_size': 512,
|
||||
'state_space_dim': 4,
|
||||
'action_space_dim': 2
|
||||
})
|
@ -0,0 +1,47 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""DQN net"""
|
||||
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops as ops
|
||||
|
||||
|
||||
class DQN(nn. Cell):
|
||||
def __init__(self, input_size, hidden_size, output_size):
|
||||
super(DQN, self).__init__()
|
||||
self.linear1 = nn.Dense(input_size, hidden_size)
|
||||
self.linear2 = nn.Dense(hidden_size, output_size)
|
||||
self.relu = nn.ReLU()
|
||||
|
||||
def construct(self, x):
|
||||
x = self.relu(self.linear1(x))
|
||||
return self.linear2(x)
|
||||
|
||||
|
||||
class WithLossCell(nn.Cell):
|
||||
"""
|
||||
network with loss function
|
||||
"""
|
||||
def __init__(self, backbone, loss_fn):
|
||||
super(WithLossCell, self).__init__(auto_prefix=False)
|
||||
self._backbone = backbone
|
||||
self._loss_fn = loss_fn
|
||||
self.gather = ops.GatherD()
|
||||
|
||||
def construct(self, x, act, label):
|
||||
out = self._backbone(x)
|
||||
out = self.gather(out, 1, act)
|
||||
loss = self._loss_fn(out, label)
|
||||
return loss
|
@ -0,0 +1,69 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Train DQN and get checkpoint files."""
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import gym
|
||||
from mindspore import context
|
||||
from mindspore.common import set_seed
|
||||
from mindspore.train.serialization import save_checkpoint
|
||||
from src.config import config_dqn as cfg
|
||||
from src.agent import Agent
|
||||
|
||||
parser = argparse.ArgumentParser(description='MindSpore dqn Example')
|
||||
parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU'],
|
||||
help='device where the code will be implemented (default: Ascend)')
|
||||
parser.add_argument('--ckpt_path', type=str, default="./ckpt", help='if is test, must provide\
|
||||
path where the trained ckpt file')
|
||||
args = parser.parse_args()
|
||||
set_seed(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||
env = gym.make('CartPole-v0')
|
||||
cfg.state_space_dim = env.observation_space.shape[0]
|
||||
cfg.action_space_dim = env.action_space.n
|
||||
agent = Agent(**cfg)
|
||||
agent.load_dict()
|
||||
|
||||
for episode in range(150):
|
||||
s0 = env.reset()
|
||||
total_reward = 1
|
||||
while True:
|
||||
a0 = agent.act(s0)
|
||||
s1, r1, done, _ = env.step(a0)
|
||||
|
||||
if done:
|
||||
r1 = -1
|
||||
|
||||
agent.put(s0, a0, r1, s1)
|
||||
|
||||
if done:
|
||||
break
|
||||
|
||||
total_reward += r1
|
||||
s0 = s1
|
||||
agent.learn()
|
||||
agent.load_dict()
|
||||
print("episode", episode, "total_reward", total_reward)
|
||||
|
||||
path = os.path.realpath(args.ckpt_path)
|
||||
if not os.path.exists(path):
|
||||
os.makedirs(path)
|
||||
|
||||
ckpt_name = path + "/dqn.ckpt"
|
||||
save_checkpoint(agent.policy_net, ckpt_name)
|
Loading…
Reference in new issue