How to Use MATLAB for Deep Reinforcement Learning (with Code)

Version Compatibility: R2021a and above

Toolbox Requirements: Reinforcement Learning Toolbox + Deep Learning Toolbox

1. Introduction

Deep Reinforcement Learning (DRL) is a class of intelligent decision-making algorithms that combines deep learning with reinforcement learning. The core idea is to learn a policy that maximizes long-term cumulative rewards through interaction with the environment.
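For reference, "long-term cumulative reward" is usually formalized as the discounted return of an episode. The notation below is the standard textbook form, not something specific to MATLAB; here $\gamma$ plays the role of the DiscountFactor option used later in this article, and $T$ is the (assumed finite) episode length.

```latex
G_t = \sum_{k=0}^{T-t-1} \gamma^{k}\, r_{t+k+1}, \qquad 0 \le \gamma \le 1,
\qquad
\pi^{*} = \arg\max_{\pi}\; \mathbb{E}_{\pi}\left[ G_0 \right]
```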

The Reinforcement Learning Toolbox was introduced in R2019a and supports training, simulating, and deploying DRL agents. Users can build and train agents without writing low-level gradient update code.

2. Basic Components of Deep Reinforcement Learning in MATLAB

A typical DRL system in MATLAB consists of the following five components:

| Module | Function |
|---|---|
| Environment | Defines the state space, action space, and reward mechanism; provides an interaction interface. |
| Agent | Holds the policy (actor) and/or the value function (critic), depending on the algorithm; responsible for learning. |
| Neural Network | Approximates the value function or policy function. |
| Training Options | Control the training process, such as the number of episodes, stopping criteria, and parallelization. |
| Simulation & Deployment | Evaluates agent performance or generates code. |

3. Environment Definition Methods

MATLAB supports three methods for building environments:

3.1 Using Predefined Environments

The toolbox ships with several classic environments, for example:

env = rlPredefinedEnv('CartPole-Discrete');

Common environments include:

  • 'CartPole-Discrete': cart-pole balancing with a discrete action set;
  • 'CartPole-Continuous': the same task with a continuous action;
  • 'SimplePendulumWithImage-Continuous', etc.
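As a quick check after creating a predefined environment, its observation and action specifications can be read back and inspected. This is a minimal sketch, assuming the Reinforcement Learning Toolbox is installed:

```matlab
% Create a predefined environment and inspect its specifications.
env = rlPredefinedEnv('CartPole-Discrete');

obsInfo = getObservationInfo(env);   % numeric spec describing the observation
actInfo = getActionInfo(env);        % finite-set spec for a discrete-action environment

disp(obsInfo.Dimension);             % size of one observation
disp(actInfo.Elements);              % the allowed discrete action values
```

The same two calls work for continuous-action environments, where the action specification has limits instead of an Elements list.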

3.2 Custom Function Environment

Users can create an environment from custom step and reset functions, which are passed to rlFunctionEnv as positional arguments:

env = rlFunctionEnv(obsInfo, actInfo, @myStep, @myReset);

Where:

  • obsInfo and actInfo define the observation (state) and action spaces;
  • myStep has the signature [nextObs, reward, isDone, loggedSignals] = myStep(action, loggedSignals);
  • myReset returns [initialObs, loggedSignals].
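To make the step/reset contract concrete, here is a minimal sketch of a function environment. The task itself (keeping a scalar position near zero) and all numeric values are made up for illustration; only the function signatures follow the toolbox convention.

```matlab
% Hypothetical 1-D task: keep a scalar position x close to zero.
obsInfo = rlNumericSpec([1 1]);        % one continuous observation
actInfo = rlFiniteSetSpec([-1 1]);     % two discrete actions: push left or right

env = rlFunctionEnv(obsInfo, actInfo, @myStep, @myReset);
validateEnvironment(env);              % runs a short reset/step check

% Local functions must appear at the end of a script file.
function [nextObs, reward, isDone, loggedSignals] = myStep(action, loggedSignals)
    x = loggedSignals.State + 0.1*action;   % illustrative dynamics
    loggedSignals.State = x;                % carry the state to the next step
    nextObs = x;
    reward  = -abs(x);                      % closer to zero is better
    isDone  = abs(x) > 5;                   % terminate when far from the origin
end

function [initialObs, loggedSignals] = myReset()
    loggedSignals.State = 2*rand - 1;       % random start in [-1, 1]
    initialObs = loggedSignals.State;
end
```

The loggedSignals structure is the place to carry state between steps, since the step function itself holds no state.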

3.3 Simulink Environment

For control system tasks, environments can be created through Simulink models:

env = rlSimulinkEnv('myModel', 'myModel/RL Agent', obsInfo, actInfo);
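The sketch below shows how the specifications and a reset function might be set up around that call. The model name myModel, the variable x0, and the signal dimensions are all hypothetical placeholders; the second argument is the full block path of the RL Agent block inside the model.

```matlab
mdl = 'myModel';                               % hypothetical Simulink model name

% Assumed signal sizes: 3 observations, 1 bounded continuous action.
obsInfo = rlNumericSpec([3 1]);
actInfo = rlNumericSpec([1 1], 'LowerLimit', -1, 'UpperLimit', 1);

env = rlSimulinkEnv(mdl, [mdl '/RL Agent'], obsInfo, actInfo);

% Randomize a (hypothetical) initial-condition variable at each episode start.
env.ResetFcn = @(in) setVariable(in, 'x0', 2*rand - 1, 'Workspace', mdl);
```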

4. Building Deep Networks

In MATLAB, deep neural networks are constructed using layerGraph + dlnetwork. The network structure depends on the type of task:

| Task Type | Network Output | Common Algorithms |
|---|---|---|
| Discrete action | Q-value for each action (DQN) or action probabilities (PG) | DQN, PG |
| Continuous action | Action, or parameters of the action distribution | DDPG, PPO, SAC |

Example (Discrete Action DQN Critic):

numObs  = obsInfo.Dimension(1);
numActs = numel(actInfo.Elements);

layers = [
    featureInputLayer(numObs,'Normalization','none','Name','state')
    fullyConnectedLayer(64,'Name','fc1')
    reluLayer('Name','relu1')
    fullyConnectedLayer(64,'Name','fc2')
    reluLayer('Name','relu2')
    fullyConnectedLayer(numActs,'Name','Qout')
];

criticNet = dlnetwork(layerGraph(layers));
critic = rlQValueRepresentation(criticNet, obsInfo, actInfo, 'Observation',{'state'});

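Note: In R2022a and later, rlQValueRepresentation is superseded by function approximator objects. For a critic like this one, which outputs one Q-value per discrete action, the replacement is rlVectorQValueFunction. rlQValueFunction is meant for critics that take both the observation and the action as inputs and return a single Q-value.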
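For readers on R2022a or later, a multi-output critic like the one above can be wrapped with rlVectorQValueFunction. The following is a sketch under that assumption; the smaller network is only there to keep the example short.

```matlab
% Requires R2022a or later.
env = rlPredefinedEnv('CartPole-Discrete');
obsInfo = getObservationInfo(env);
actInfo = getActionInfo(env);

layers = [
    featureInputLayer(obsInfo.Dimension(1), 'Name', 'state')
    fullyConnectedLayer(64, 'Name', 'fc1')
    reluLayer('Name', 'relu1')
    fullyConnectedLayer(numel(actInfo.Elements), 'Name', 'Qout')
];
criticNet = dlnetwork(layers);

% Multi-output critic: one Q-value per discrete action.
critic = rlVectorQValueFunction(criticNet, obsInfo, actInfo, ...
    'ObservationInputNames', 'state');

% Sanity check: evaluate the critic on a random observation.
q = getValue(critic, {rand(obsInfo.Dimension)});
disp(size(q));
```

The resulting critic can be passed to rlDQNAgent in the same way as the rlQValueRepresentation object.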

5. Creating Agents

5.1 Discrete Action: DQN

agentOpts = rlDQNAgentOptions( ...
    'UseDoubleDQN', true, ...
    'TargetUpdateFrequency', 4, ...
    'ExperienceBufferLength', 1e6, ...
    'MiniBatchSize', 256, ...
    'DiscountFactor', 0.99);

% Exploration settings are set with dot notation on the options object.
agentOpts.EpsilonGreedyExploration.Epsilon      = 1.0;
agentOpts.EpsilonGreedyExploration.EpsilonMin   = 0.01;
agentOpts.EpsilonGreedyExploration.EpsilonDecay = 1e-4;

agent = rlDQNAgent(critic, agentOpts);
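To get a feel for the exploration settings above, it helps to estimate how long epsilon takes to decay. The toolbox multiplies Epsilon by (1 - EpsilonDecay) after each training step while Epsilon is above EpsilonMin. The sketch below computes the number of steps both in closed form and by iterating that rule; the variable names are illustrative.

```matlab
epsilon      = 1.0;     % initial exploration probability
epsilonMin   = 0.01;    % floor
epsilonDecay = 1e-4;    % per-step decay rate

% Closed form: smallest n with epsilon*(1 - epsilonDecay)^n <= epsilonMin.
nSteps = ceil(log(epsilonMin/epsilon) / log(1 - epsilonDecay));

% Same count by applying the update rule step by step.
n = 0;
e = epsilon;
while e > epsilonMin
    e = e*(1 - epsilonDecay);
    n = n + 1;
end

fprintf('Closed form: %d steps, iteration: %d steps\n', nSteps, n);
```

With these values the decay takes roughly 46,000 agent steps, so with at most 500 steps per episode the agent keeps exploring for at least about 90 episodes. A larger EpsilonDecay shortens that phase.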

5.2 Continuous Action: DDPG / PPO / SAC

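Continuous control agents such as DDPG, PPO, and SAC use both an actor and a critic. The class names depend on the release:

  • Actors: rlDeterministicActorRepresentation / rlStochasticActorRepresentation (R2021b and earlier); rlContinuousDeterministicActor / rlContinuousGaussianActor (R2022a and later)
  • Critics: rlValueRepresentation / rlQValueRepresentation (R2021b and earlier); rlValueFunction / rlQValueFunction (R2022a and later)
  • Agents: rlDDPGAgent, rlPPOAgent, rlSACAgent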
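When the exact network architecture is not important, an actor-critic agent can also be created with default networks directly from the environment specifications, which avoids the version-specific actor and critic classes. This is a minimal sketch, assuming R2020b or later; the hidden layer size of 64 is an arbitrary choice.

```matlab
env = rlPredefinedEnv('CartPole-Continuous');
obsInfo = getObservationInfo(env);
actInfo = getActionInfo(env);

% Let the toolbox build default actor and critic networks.
initOpts = rlAgentInitializationOptions('NumHiddenUnit', 64);
agent = rlDDPGAgent(obsInfo, actInfo, initOpts);

% Query the (untrained) policy for a random observation.
action = getAction(agent, {rand(obsInfo.Dimension)});
disp(action{1});
```

The same pattern applies to rlPPOAgent and rlSACAgent. Custom actor and critic objects are only needed when the default architecture is not suitable.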

6. Training and Evaluation

6.1 Training the Agent

trainOpts = rlTrainingOptions( ...
    'MaxEpisodes', 500, ...
    'MaxStepsPerEpisode', 500, ...
    'StopTrainingCriteria','AverageReward', ...
    'StopTrainingValue',475, ...
    'ScoreAveragingWindowLength',20, ...
    'Verbose',false, ...
    'Plots','training-progress');

trainingStats = train(agent, env, trainOpts);

During training, the Episode Manager window opens and plots the reward of each episode together with its running average.
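The same curves can be reproduced afterwards from the value returned by train. The sketch below assumes trainingStats comes from the train call above and uses its EpisodeIndex, EpisodeReward, and AverageReward fields.

```matlab
% Assumes trainingStats was returned by train(agent, env, trainOpts).
figure;
plot(trainingStats.EpisodeIndex, trainingStats.EpisodeReward);
hold on;
plot(trainingStats.EpisodeIndex, trainingStats.AverageReward, 'LineWidth', 1.5);
hold off;

xlabel('Episode');
ylabel('Reward');
legend('Episode reward', 'Average reward', 'Location', 'best');
```

This is handy when training runs with 'Plots','none', for example on a machine without a display.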

6.2 Evaluation and Simulation

simOptions = rlSimulationOptions('MaxSteps',500);
experience = sim(env, agent, simOptions);
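The structure returned by sim stores the reward signal as a timeseries, so the total reward of a run is the sum of its data. The sketch below averages over several runs; to keep it self-contained it uses an untrained default DQN agent as a stand-in, which is an assumption for illustration only (in practice, use the trained agent).

```matlab
env = rlPredefinedEnv('CartPole-Discrete');

% Stand-in agent with default networks (untrained); replace with the trained agent.
agent = rlDQNAgent(getObservationInfo(env), getActionInfo(env));

simOptions  = rlSimulationOptions('MaxSteps', 500, 'NumSimulations', 5);
experiences = sim(env, agent, simOptions);     % one struct per simulation

% Sum the reward timeseries of each run.
totalRewards = arrayfun(@(e) sum(e.Reward.Data), experiences);
fprintf('Mean total reward over %d runs: %.1f\n', ...
    numel(totalRewards), mean(totalRewards));
```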

6.3 Saving the Model

save('trainedAgent.mat','agent');
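A saved agent can later be restored and simulated again without retraining. The sketch below assumes trainedAgent.mat was written by the save call above and is on the current path. The final call generates a standalone policy function for deployment.

```matlab
% Restore the agent saved earlier.
data  = load('trainedAgent.mat', 'agent');
agent = data.agent;

% Re-run it in a fresh environment instance.
env = rlPredefinedEnv('CartPole-Discrete');
experience = sim(env, agent, rlSimulationOptions('MaxSteps', 500));
fprintf('Total reward: %.1f\n', sum(experience.Reward.Data));

% Optional: write evaluatePolicy.m and agentData.mat to the current folder.
generatePolicyFunction(agent);
```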

Appendix: Complete DQN Example (R2021a Compatible)

rng(0);  % seed first so network initialization and exploration are reproducible

% Environment and its observation/action specifications
env = rlPredefinedEnv('CartPole-Discrete');
obsInfo = getObservationInfo(env);
actInfo = getActionInfo(env);

numObs  = obsInfo.Dimension(1);
numActs = numel(actInfo.Elements);

% Critic network: observation in, one Q-value per discrete action out
layers = [
    featureInputLayer(numObs,'Normalization','none','Name','state')
    fullyConnectedLayer(64,'Name','fc1')
    reluLayer('Name','relu1')
    fullyConnectedLayer(64,'Name','fc2')
    reluLayer('Name','relu2')
    fullyConnectedLayer(numActs,'Name','Qout')
];

criticNet = dlnetwork(layerGraph(layers));
critic = rlQValueRepresentation(criticNet, obsInfo, actInfo, 'Observation',{'state'});

% Agent options
agentOpts = rlDQNAgentOptions( ...
    'UseDoubleDQN',true, ...
    'TargetUpdateFrequency',4, ...
    'ExperienceBufferLength',1e6, ...
    'MiniBatchSize',256, ...
    'DiscountFactor',0.99);

% Exploration settings are set with dot notation on the options object.
agentOpts.EpsilonGreedyExploration.Epsilon      = 1.0;
agentOpts.EpsilonGreedyExploration.EpsilonMin   = 0.01;
agentOpts.EpsilonGreedyExploration.EpsilonDecay = 1e-4;

agent = rlDQNAgent(critic, agentOpts);

% Training: stop early once the 20-episode average reward reaches 475
trainOpts = rlTrainingOptions( ...
    'MaxEpisodes',500, ...
    'MaxStepsPerEpisode',500, ...
    'StopTrainingCriteria','AverageReward', ...
    'StopTrainingValue',475, ...
    'ScoreAveragingWindowLength',20, ...
    'Verbose',false, ...
    'Plots','training-progress');

trainingStats = train(agent, env, trainOpts);

% Run the trained agent once in the environment
sim(env, agent);
