Discrete cartpole environment does not train properly
ryuuzi
on 25 Oct 2024
Answered: Hiro Yoshino
on 5 Nov 2024 at 7:17
I created a discrete cartpole environment based on "Create Custom Environment from Class Template" and imported it into the Reinforcement Learning Designer.
However, training did not converge stably. I experimented through trial and error, but could not find a fix.
Any advice would be appreciated.
classdef matlab < rl.env.MATLABEnvironment
properties
% Acceleration due to gravity in m/s^2
Gravity = 9.8
% Mass of the cart
MassCart = 1.0
% Mass of the pole
MassPole = 0.1
% Half the length of the pole
Length = 0.5
% Max force the input can apply
MaxForce = 10
% Sample time
Ts = 0.02
% Angle at which to fail the episode
ThetaThresholdRadians = 12 * pi/180
% Distance at which to fail the episode
XThreshold = 2.4
% Reward each time step the cart-pole is balanced
RewardForNotFalling = 1
% Penalty when the cart-pole fails to balance
PenaltyForFalling = -5
end
properties
% system state [x,dx,theta,dtheta]'
State = zeros(4,1)
end
properties(Access = protected)
% Internal flag to indicate episode termination
IsDone = false
end
properties (Transient,Access = private)
Visualizer = []
end
methods
function this = matlab()
ObservationInfo = rlNumericSpec([4 1]);
ObservationInfo.Name = 'CartPole States';
ObservationInfo.Description = 'x, dx, theta, dtheta';
ActionInfo = rlFiniteSetSpec([-1 1]);
ActionInfo.Name = 'CartPole Action';
this = this@rl.env.MATLABEnvironment(ObservationInfo, ActionInfo);
updateActionInfo(this);
end
function set.State(this,state)
validateattributes(state,{'numeric'},{'finite','real','vector','numel',4},'','State');
this.State = double(state(:));
notifyEnvUpdated(this);
end
function set.Length(this,val)
validateattributes(val,{'numeric'},{'finite','real','positive','scalar'},'','Length');
this.Length = val;
notifyEnvUpdated(this);
end
function set.Gravity(this,val)
validateattributes(val,{'numeric'},{'finite','real','positive','scalar'},'','Gravity');
this.Gravity = val;
end
function set.MassCart(this,val)
validateattributes(val,{'numeric'},{'finite','real','positive','scalar'},'','MassCart');
this.MassCart = val;
end
function set.MassPole(this,val)
validateattributes(val,{'numeric'},{'finite','real','positive','scalar'},'','MassPole');
this.MassPole = val;
end
function set.MaxForce(this,val)
validateattributes(val,{'numeric'},{'finite','real','positive','scalar'},'','MaxForce');
this.MaxForce = val;
updateActionInfo(this);
end
function set.Ts(this,val)
validateattributes(val,{'numeric'},{'finite','real','positive','scalar'},'','Ts');
this.Ts = val;
end
function set.ThetaThresholdRadians(this,val)
validateattributes(val,{'numeric'},{'finite','real','positive','scalar'},'','ThetaThresholdRadians');
this.ThetaThresholdRadians = val;
notifyEnvUpdated(this);
end
function set.XThreshold(this,val)
validateattributes(val,{'numeric'},{'finite','real','positive','scalar'},'','XThreshold');
this.XThreshold = val;
notifyEnvUpdated(this);
end
function set.RewardForNotFalling(this,val)
validateattributes(val,{'numeric'},{'real','finite','scalar'},'','RewardForNotFalling');
this.RewardForNotFalling = val;
end
function set.PenaltyForFalling(this,val)
validateattributes(val,{'numeric'},{'real','finite','scalar'},'','PenaltyForFalling');
this.PenaltyForFalling = val;
end
function [observation,reward,isdone,loggedSignals] = step(this,action)
loggedSignals = [];
% Get action
force = getForce(this,action);
% Unpack state vector
state = this.State;
%x = state(1);
x_dot = state(2);
theta = state(3);
theta_dot = state(4);
% Apply motion equations
costheta = cos(theta);
sintheta = sin(theta);
totalmass = this.MassCart + this.MassPole;
polemasslength = this.MassPole*this.Length;
temp = (force + polemasslength * theta_dot * theta_dot * sintheta) / totalmass;
thetaacc = (this.Gravity * sintheta - costheta* temp) / (this.Length * (4.0/3.0 - this.MassPole * costheta * costheta / totalmass));
xacc = temp - polemasslength * thetaacc * costheta / totalmass;
% Euler integration
observation = state + this.Ts.*[x_dot;xacc;theta_dot;thetaacc];
this.State = observation;
x = observation(1);
theta = observation(3);
isdone = abs(x) > this.XThreshold || abs(theta) > this.ThetaThresholdRadians;
this.IsDone = isdone;
% Get reward
reward = getReward(this,x,force);
end
function initialState = reset(this)
% Randomize the initial pendulum angle between (+- .05 rad)
% Theta (+- .05 rad)
T0 = 2*0.05*rand - 0.05;
% Thetadot
Td0 = 0;
% X
X0 = 0;
% Xdot
Xd0 = 0;
initialState= [X0;Xd0;T0;Td0];
this.State = initialState;
end
function varargout = plot(this)
% Visualizes the environment
if isempty(this.Visualizer) || ~isvalid(this.Visualizer)
this.Visualizer = rl.env.viz.CartPoleVisualizer(this);
else
bringToFront(this.Visualizer);
end
if nargout
varargout{1} = this.Visualizer;
end
end
end
methods (Access = protected)
function force = getForce(this,action)
if ~ismember(action,this.ActionInfo.Elements)
error(message('rl:env:CartPoleDiscreteInvalidAction',sprintf('%g',-this.MaxForce),sprintf('%g',this.MaxForce)));
end
force = action;
end
% Update the action info based on max force
function updateActionInfo(this)
% Discrete action set: -MaxForce and +MaxForce
this.ActionInfo.Elements = this.MaxForce*[-1 1];
end
function Reward = getReward(this,~,~)
if ~this.IsDone
Reward = this.RewardForNotFalling;
else
Reward = this.PenaltyForFalling;
end
end
end
end
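As a quick sanity check on a custom environment like the one above (a minimal sketch, assuming Reinforcement Learning Toolbox is installed and the class is saved as `matlab.m` on the path; this session is not part of the original post):

```matlab
% Instantiate the custom environment defined above
env = matlab();

% validateEnvironment runs reset/step and checks that their outputs are
% consistent with the observation and action specifications; it catches
% mismatches such as an action set that differs from ActionInfo.Elements
validateEnvironment(env)

% Inspect the discrete action set actually exposed to the agent
disp(getActionInfo(env).Elements)
```

If `validateEnvironment` passes but training still diverges, the next things to check are the reward scale and the agent's hyperparameters in the Reinforcement Learning Designer.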
Accepted Answer
Hiro Yoshino
on 5 Nov 2024 at 7:17
The built-in examples include a discrete cartpole, so opening a working one and examining its contents may be a useful reference (the answer may be in there).
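For example (a sketch, assuming Reinforcement Learning Toolbox), the predefined discrete-action cartpole can be created and its class source opened for a side-by-side comparison with the custom implementation:

```matlab
% Create the predefined discrete-action cartpole environment
env = rlPredefinedEnv('CartPole-Discrete');

% Open the class definition of the working environment for comparison
edit(class(env))

% Compare its discrete action set with the custom environment's
disp(getActionInfo(env).Elements)
```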