pull/13184/head
xiaoyisd 4 years ago
parent e206e712f3
commit a941422337

README.md
@@ -0,0 +1,128 @@
# Contents
- [DQN Description](#dqn-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Requirements](#requirements)
- [Script Description](#script-description)
- [Script and Sample Code](#script-and-sample-code)
- [Script Parameters](#script-parameters)
- [Training Process](#training-process)
- [Evaluation Process](#evaluation-process)
- [Model Description](#model-description)
- [Performance](#performance)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)
# [DQN Description](#contents)
DQN is the first deep learning model to successfully learn control policies directly from high-dimensional sensory input using reinforcement learning.
[Paper](https://www.nature.com/articles/nature14236) Mnih, Volodymyr, Koray Kavukcuoglu, David Silver, Andrei A. Rusu, Joel Veness, Marc G. Bellemare, Alex Graves et al. "Human-level control through deep reinforcement learning." Nature 518, no. 7540 (2015): 529-533.
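Concretely, the implementation in this repository (see `learn()` in `src/agent.py`) regresses the Q-value of the taken action toward the standard one-step TD target, where θ⁻ denotes the periodically synchronized target network and γ is the `gamma` parameter listed under Script Parameters below:

```latex
L(\theta) = \mathbb{E}\Big[\big(r + \gamma \max_{a'} Q(s', a'; \theta^{-}) - Q(s, a; \theta)\big)^{2}\Big]
```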
## [Model Architecture](#content)
The overall network architecture of DQN is described in the paper below:
[Paper](https://www.nature.com/articles/nature14236)
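Since the CartPole-v0 state is a 4-dimensional vector, this repository replaces the paper's convolutional network with a small multilayer perceptron (two `nn.Dense` layers with a ReLU in between, hidden size 256), as defined in `src/dqn.py`. A minimal self-contained sketch of that structure (`DQNSketch` is an illustrative name; the actual class is `DQN`):

```python
import mindspore.nn as nn

class DQNSketch(nn.Cell):
    """Two dense layers with a ReLU in between: state -> one Q-value per action."""
    def __init__(self, input_size=4, hidden_size=256, output_size=2):
        super(DQNSketch, self).__init__()
        self.linear1 = nn.Dense(input_size, hidden_size)
        self.linear2 = nn.Dense(hidden_size, output_size)
        self.relu = nn.ReLU()

    def construct(self, x):
        # x: (batch, state_space_dim) -> (batch, action_space_dim)
        return self.linear2(self.relu(self.linear1(x)))
```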
## [Dataset](#content)

DQN does not use a fixed dataset; training and evaluation interact with the CartPole-v0 environment from OpenAI Gym.
## [Requirements](#content)
- Hardware (Ascend/GPU)
- Prepare hardware environment with Ascend or GPU processor. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources.
- Framework
- [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below
- [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
- Third-party libraries (a quick environment check is shown after the install command below)
```bash
pip install gym
```
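The scripts train and evaluate on Gym's classic `CartPole-v0` environment, so no data download is needed. A quick sanity check that the environment is available, assuming a pre-0.26 `gym` release whose `step` returns four values as the code expects:

```python
import gym

env = gym.make('CartPole-v0')
state = env.reset()                                         # 4-dimensional observation
obs, reward, done, _ = env.step(env.action_space.sample())  # take one random action
print(env.observation_space.shape[0], env.action_space.n)   # 4 2 -> state/action space dims
```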
## [Script Description](#content)
### [Scripts and Sample Code](#contents)
```python
├── dqn
    ├── README.md                           # descriptions about DQN
    ├── scripts
    │   ├── run_standalone_eval_ascend.sh   # shell script for evaluation with Ascend
    │   ├── run_standalone_eval_gpu.sh      # shell script for evaluation with GPU
    │   ├── run_standalone_train_ascend.sh  # shell script for training with Ascend
    │   ├── run_standalone_train_gpu.sh     # shell script for training with GPU
    ├── src
    │   ├── agent.py                        # model agent
    │   ├── config.py                       # parameter configuration
    │   ├── dqn.py                          # dqn architecture
    ├── train.py                            # training script
    ├── eval.py                             # evaluation script
```
### [Script Parameters](#content)
```python
'gamma': 0.8              # discount factor applied to the next state value
'epsi_high': 0.9          # initial (highest) exploration rate
'epsi_low': 0.05          # final (lowest) exploration rate
'decay': 200              # decay constant (in steps) of the exploration rate
'lr': 0.001               # learning rate
'capacity': 100000        # capacity of the replay buffer
'batch_size': 512         # training batch size
'state_space_dim': 4      # dimension of the environment state space
'action_space_dim': 2     # dimension of the action space
```
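For reference, `epsi_high`, `epsi_low`, and `decay` define the exponentially decaying epsilon-greedy exploration rate used in `Agent.act` (`src/agent.py`); a standalone sketch of that schedule:

```python
import math

def epsilon(step, epsi_high=0.9, epsi_low=0.05, decay=200):
    """Exploration rate after `step` environment steps (mirrors Agent.act)."""
    return epsi_low + (epsi_high - epsi_low) * math.exp(-1.0 * step / decay)

print(round(epsilon(0), 2))     # 0.9  -> explore almost always at the start
print(round(epsilon(200), 2))   # 0.36 -> one decay constant later
print(round(epsilon(1000), 2))  # 0.06 -> close to the floor epsi_low
```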
### [Training Process](#content)
```shell
# training example
python:
    Ascend: python train.py --device_target Ascend --ckpt_path ckpt > log.txt 2>&1 &
    GPU: python train.py --device_target GPU --ckpt_path ckpt > log.txt 2>&1 &
shell:
    Ascend: sh run_standalone_train_ascend.sh ckpt
    GPU: sh run_standalone_train_gpu.sh ckpt
```
### [Evaluation Process](#content)
```shell
# evaluation example
python:
    Ascend: python eval.py --device_target Ascend --ckpt_path ./ckpt/dqn.ckpt
    GPU: python eval.py --device_target GPU --ckpt_path ./ckpt/dqn.ckpt
shell:
    Ascend: sh run_standalone_eval_ascend.sh ./ckpt/dqn.ckpt
    GPU: sh run_standalone_eval_gpu.sh ./ckpt/dqn.ckpt
```
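After training, the checkpoint written by `train.py` can be loaded back into the policy network the same way `eval.py` does. A minimal sketch, assuming the default output path `./ckpt/dqn.ckpt` produced by `train.py`:

```python
import gym
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.config import config_dqn as cfg
from src.agent import Agent

env = gym.make('CartPole-v0')
cfg.state_space_dim = env.observation_space.shape[0]
cfg.action_space_dim = env.action_space.n
agent = Agent(**cfg)

param_dict = load_checkpoint("./ckpt/dqn.ckpt")    # path assumed from train.py defaults
load_param_into_net(agent.policy_net, param_dict)
action = agent.eval_act(env.reset())               # greedy action for the initial state
```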
## [Performance](#content)
### Inference Performance
| Parameters | DQN |
| -------------------------- | ----------------------------------------------------------- |
| Resource                   | Ascend 910; CPU 2.60GHz, 192 cores; memory 755 GB           |
| Uploaded Date              | 03/10/2021 (month/day/year)                                  |
| MindSpore Version          | 1.1.0                                                        |
| Training Parameters        | batch_size = 512, lr = 0.001                                 |
| Optimizer                  | RMSProp                                                      |
| Loss Function              | MSELoss                                                      |
| Outputs                    | Q-value of each action                                       |
| Params (K)                 | 7.3                                                          |
| Scripts                    | https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/rl/dqn |
## [Description of Random Situation](#content)
A fixed random seed is set in train.py and eval.py (`set_seed(1)`).
## [ModelZoo Homepage](#contents)
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

eval.py
@@ -0,0 +1,67 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Evaluation for DQN"""
import argparse
import gym
from mindspore import context
from mindspore.common import set_seed
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.config import config_dqn as cfg
from src.agent import Agent
parser = argparse.ArgumentParser(description='MindSpore dqn Example')
parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU'],
                    help='device where the code will be implemented (default: Ascend)')
parser.add_argument('--ckpt_path', type=str, default=None,
                    help='path of the trained ckpt file used for evaluation')
args = parser.parse_args()
set_seed(1)

if __name__ == "__main__":
    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
    env = gym.make('CartPole-v0')
    cfg.state_space_dim = env.observation_space.shape[0]
    cfg.action_space_dim = env.action_space.n
    agent = Agent(**cfg)

    # load checkpoint
    if args.ckpt_path:
        param_dict = load_checkpoint(args.ckpt_path)
        not_load_param = load_param_into_net(agent.policy_net, param_dict)
        if not_load_param:
            raise ValueError("Load param into net fail!")

    score = 0
    agent.load_dict()
    for episode in range(50):
        s0 = env.reset()
        total_reward = 1
        while True:
            a0 = agent.eval_act(s0)
            s1, r1, done, _ = env.step(a0)
            if done:
                break
            total_reward += r1
            s0 = s1
        score += total_reward
        print("episode", episode, "total_reward", total_reward)
    print("mean_reward", score / 50)

scripts/run_standalone_eval_ascend.sh
@@ -0,0 +1,21 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# A simple tutorial follows; more parameters can be set.
script_self=$(readlink -f "$0")
self_path=$(dirname "${script_self}")
CKPT_PATH=$1
python -s ${self_path}/../eval.py --device_target="Ascend" --ckpt_path=$CKPT_PATH > log.txt 2>&1 &

scripts/run_standalone_eval_gpu.sh
@@ -0,0 +1,21 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# A simple tutorial follows; more parameters can be set.
script_self=$(readlink -f "$0")
self_path=$(dirname "${script_self}")
CKPT_PATH=$1
python -s ${self_path}/../eval.py --device_target="GPU" --ckpt_path=$CKPT_PATH > log.txt 2>&1 &

scripts/run_standalone_train_ascend.sh
@@ -0,0 +1,21 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# A simple tutorial follows; more parameters can be set.
script_self=$(readlink -f "$0")
self_path=$(dirname "${script_self}")
CKPT_PATH=$1
python -s ${self_path}/../train.py --device_target="Ascend" --ckpt_path=$CKPT_PATH > log.txt 2>&1 &

scripts/run_standalone_train_gpu.sh
@@ -0,0 +1,21 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# A simple tutorial follows; more parameters can be set.
script_self=$(readlink -f "$0")
self_path=$(dirname "${script_self}")
CKPT_PATH=$1
python -s ${self_path}/../train.py --device_target="GPU" --ckpt_path=$CKPT_PATH > log.txt 2>&1 &

src/agent.py
@@ -0,0 +1,94 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Agent of reinforcement learning network"""
import random
import math
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor
import mindspore.common.dtype as mstype
from src.dqn import DQN, WithLossCell
class Agent:
    """
    DQN Agent
    """
    def __init__(self, **kwargs):
        for key, value in kwargs.items():
            setattr(self, key, value)
        self.policy_net = DQN(self.state_space_dim, 256, self.action_space_dim)
        self.target_net = DQN(self.state_space_dim, 256, self.action_space_dim)
        self.optimizer = nn.RMSProp(self.policy_net.trainable_params(), learning_rate=self.lr)
        loss_fn = nn.MSELoss()
        loss_q_net = WithLossCell(self.policy_net, loss_fn)
        self.policy_net_train = nn.TrainOneStepCell(loss_q_net, self.optimizer)
        self.policy_net_train.set_train(mode=True)
        self.buffer = []
        self.steps = 0

    def act(self, s0):
        """
        Agent chooses an action with epsilon-greedy exploration.
        """
        self.steps += 1
        epsi = self.epsi_low + (self.epsi_high - self.epsi_low) * (math.exp(-1.0 * self.steps / self.decay))
        if random.random() < epsi:
            a0 = random.randrange(self.action_space_dim)
        else:
            s0 = np.expand_dims(s0, axis=0)
            s0 = Tensor(s0, mstype.float32)
            a0 = self.policy_net(s0).asnumpy()
            a0 = np.argmax(a0)
        return a0

    def eval_act(self, s0):
        """
        Agent chooses the greedy action (no exploration).
        """
        self.steps += 1
        s0 = np.expand_dims(s0, axis=0)
        s0 = Tensor(s0, mstype.float32)
        a0 = self.policy_net(s0).asnumpy()
        a0 = np.argmax(a0)
        return a0

    def put(self, *transition):
        """Store a transition, evicting the oldest one when the buffer is full."""
        if len(self.buffer) == self.capacity:
            self.buffer.pop(0)
        self.buffer.append(transition)

    def load_dict(self):
        """Copy the policy network parameters into the target network."""
        for target_item, source_item in zip(self.target_net.parameters_dict(), self.policy_net.parameters_dict()):
            target_param = self.target_net.parameters_dict()[target_item]
            source_param = self.policy_net.parameters_dict()[source_item]
            target_param.set_data(source_param.data)

    def learn(self):
        """
        Agent learns from experience data.
        """
        if len(self.buffer) < self.batch_size:
            return

        samples = random.sample(self.buffer, self.batch_size)
        s0, a0, r1, s1 = zip(*samples)
        s1 = Tensor(s1, mstype.float32)
        s0 = Tensor(s0, mstype.float32)
        a0 = Tensor(np.expand_dims(a0, axis=1))
        next_state_values = self.target_net(s1).asnumpy()
        next_state_values = np.max(next_state_values, axis=1)

        # one-step TD target: r + gamma * max_a' Q_target(s', a')
        y_true = r1 + self.gamma * next_state_values
        y_true = Tensor(np.expand_dims(y_true, axis=1), mstype.float32)
        self.policy_net_train(s0, a0, y_true)

src/config.py
@@ -0,0 +1,31 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in train.py and eval.py
"""
from easydict import EasyDict as edict
config_dqn = edict({
    'gamma': 0.8,
    'epsi_high': 0.9,
    'epsi_low': 0.05,
    'decay': 200,
    'lr': 0.001,
    'capacity': 100000,
    'batch_size': 512,
    'state_space_dim': 4,
    'action_space_dim': 2
})

src/dqn.py
@@ -0,0 +1,47 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""DQN net"""
import mindspore.nn as nn
import mindspore.ops as ops
class DQN(nn.Cell):
    """DQN policy network: two dense layers with a ReLU in between."""
    def __init__(self, input_size, hidden_size, output_size):
        super(DQN, self).__init__()
        self.linear1 = nn.Dense(input_size, hidden_size)
        self.linear2 = nn.Dense(hidden_size, output_size)
        self.relu = nn.ReLU()

    def construct(self, x):
        x = self.relu(self.linear1(x))
        return self.linear2(x)


class WithLossCell(nn.Cell):
    """
    network with loss function
    """
    def __init__(self, backbone, loss_fn):
        super(WithLossCell, self).__init__(auto_prefix=False)
        self._backbone = backbone
        self._loss_fn = loss_fn
        self.gather = ops.GatherD()

    def construct(self, x, act, label):
        out = self._backbone(x)
        # select the Q-value of the action that was actually taken
        out = self.gather(out, 1, act)
        loss = self._loss_fn(out, label)
        return loss

train.py
@@ -0,0 +1,69 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Train DQN and get checkpoint files."""
import os
import argparse
import gym
from mindspore import context
from mindspore.common import set_seed
from mindspore.train.serialization import save_checkpoint
from src.config import config_dqn as cfg
from src.agent import Agent
parser = argparse.ArgumentParser(description='MindSpore dqn Example')
parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU'],
                    help='device where the code will be implemented (default: Ascend)')
parser.add_argument('--ckpt_path', type=str, default="./ckpt",
                    help='path where the trained ckpt file will be saved')
args = parser.parse_args()
set_seed(1)

if __name__ == "__main__":
    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
    env = gym.make('CartPole-v0')
    cfg.state_space_dim = env.observation_space.shape[0]
    cfg.action_space_dim = env.action_space.n
    agent = Agent(**cfg)
    agent.load_dict()

    for episode in range(150):
        s0 = env.reset()
        total_reward = 1
        while True:
            a0 = agent.act(s0)
            s1, r1, done, _ = env.step(a0)
            if done:
                r1 = -1
            agent.put(s0, a0, r1, s1)
            if done:
                break
            total_reward += r1
            s0 = s1
            agent.learn()
        # synchronize the target network at the end of each episode
        agent.load_dict()
        print("episode", episode, "total_reward", total_reward)

    path = os.path.realpath(args.ckpt_path)
    if not os.path.exists(path):
        os.makedirs(path)
    ckpt_name = path + "/dqn.ckpt"
    save_checkpoint(agent.policy_net, ckpt_name)