classdef TD3Agent < handle
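% TD3Agent  Twin Delayed DDPG (TD3) agent built on dlnetwork-based
% actor/critic wrappers. Uses target policy smoothing, delayed policy
% updates, and soft (Polyak) target-network updates. Note that this
% variant keeps a single critic rather than the twin critics of the
% original TD3 algorithm.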
properties
actor
critic
actor_lr
critic_lr
target_actor
target_critic
replay_buffer
discount_factor
tau
policy_noise
noise_clip
policy_delay
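% Adam optimizer state (moving averages of the gradients and squared
% gradients, passed to and returned by adamupdate)
actor_avg_grad
actor_avg_sqgrad
critic_avg_grad
critic_avg_sqgrad
% Number of gradient steps taken; drives the delayed policy update
train_step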
end
methods
function obj = TD3Agent(state_dim, action_dim, max_action, actor_lr, critic_lr, discount_factor, tau, policy_noise, noise_clip, policy_delay)
obj.actor = ActorNetwork(state_dim, action_dim, max_action);
obj.critic = CriticNetwork(state_dim, action_dim);
obj.target_actor = ActorNetwork(state_dim, action_dim, max_action);
obj.target_critic = CriticNetwork(state_dim, action_dim);
obj.replay_buffer = ReplayBuffer(1000000, state_dim, action_dim);
obj.discount_factor = discount_factor;
obj.tau = tau;
obj.policy_noise = policy_noise;
obj.noise_clip = noise_clip;
obj.policy_delay = policy_delay;
obj.critic_lr = critic_lr;
obj.actor_lr = actor_lr;
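% Adam state starts empty; adamupdate initializes it on the first call
obj.critic_avg_grad = [];
obj.critic_avg_sqgrad = [];
obj.actor_avg_grad = [];
obj.actor_avg_sqgrad = [];
obj.train_step = 0;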
% Initialize the target networks with the same weights as the online actor and critic
obj.target_actor.net.Learnables.Value = obj.actor.net.Learnables.Value;
obj.target_critic.net.Learnables.Value = obj.critic.net.Learnables.Value;
end
function action = selectAction(obj, state)
action = obj.actor.forward(state);
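% Add clipped Gaussian exploration noise (this implementation reuses
% policy_noise and noise_clip for exploration)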
noise = obj.policy_noise * randn(size(action));
noise = max(min(noise, obj.noise_clip), -obj.noise_clip);
action = action + noise;
action = max(min(action, 1), -1); % keep the action within [-1, 1]
end
function train(obj, batch_size)
obj.train_step = obj.train_step + 1;
[states, actions, rewards, next_states, done] = obj.replay_buffer.sample(batch_size);
% Compute critic gradients via dlfeval so dlgradient can trace the loss
critic_gradients = dlfeval(@(s, a, r, ns, d) obj.computeCriticGradients(s, a, r, ns, d), ...
states, actions, rewards, next_states, done);
% Adam update of the critic network
[obj.critic.net, obj.critic_avg_grad, obj.critic_avg_sqgrad] = adamupdate( ...
obj.critic.net, critic_gradients, obj.critic_avg_grad, obj.critic_avg_sqgrad, ...
obj.train_step, obj.critic_lr, 0.9, 0.999);
% Delayed policy update: refresh the actor and the target networks
% only every policy_delay training steps
if mod(obj.train_step, obj.policy_delay) == 0
actor_gradients = dlfeval(@(s) obj.computeActorGradients(s), states);
[obj.actor.net, obj.actor_avg_grad, obj.actor_avg_sqgrad] = adamupdate( ...
obj.actor.net, actor_gradients, obj.actor_avg_grad, obj.actor_avg_sqgrad, ...
obj.train_step, obj.actor_lr, 0.9, 0.999);
% Soft-update the target networks toward the online networks
obj.updateTargetNetworks();
end
end
function gradients = computeCriticGradients(obj, states, actions, rewards, next_states, done)
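% Evaluate the critic loss and differentiate it with respect to the critic
% learnables; must be invoked through dlfeval so dlgradient can trace the computation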
loss = obj.compute_critic_loss(states, actions, rewards, next_states, done);
gradients = dlgradient(loss, obj.critic.net.Learnables);
end
function gradients = computeActorGradients(obj, states)
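% Deterministic policy gradient: minimize -mean(Q(s, actor(s))) with respect
% to the actor learnables; must be invoked through dlfeval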
actor_actions = obj.actor.forward(states);
actor_loss = -mean(obj.critic.forward(states, actor_actions), 'all');
gradients = dlgradient(actor_loss, obj.actor.net.Learnables);
end
function loss = compute_critic_loss(obj, states, actions, rewards, next_states, done)
% Target policy smoothing: perturb the target action with clipped noise,
% then clip the result back into the valid action range [-1, 1]
next_actions = obj.target_actor.forward(next_states);
noise = obj.policy_noise * randn(size(next_actions));
noise = max(min(noise, obj.noise_clip), -obj.noise_clip);
next_actions = max(min(next_actions + noise, 1), -1);
% Bootstrapped TD target, masked by the terminal flag
target_Q = obj.target_critic.forward(next_states, next_actions);
target_Q = rewards + (1 - done) .* obj.discount_factor .* target_Q;
% Mean-squared Bellman error (scalar loss for dlgradient)
current_Q = obj.critic.forward(states, actions);
loss = mean((current_Q - target_Q).^2, 'all');
end
function updateTargetNetworks(obj)
% Polyak (soft) update: theta_target <- tau*theta + (1 - tau)*theta_target.
% Learnables.Value is a cell array of dlarrays, so blend element-wise with cellfun.
obj.target_actor.net.Learnables.Value = cellfun(@(t, s) obj.tau .* s + (1 - obj.tau) .* t, ...
obj.target_actor.net.Learnables.Value, obj.actor.net.Learnables.Value, 'UniformOutput', false);
obj.target_critic.net.Learnables.Value = cellfun(@(t, s) obj.tau .* s + (1 - obj.tau) .* t, ...
obj.target_critic.net.Learnables.Value, obj.critic.net.Learnables.Value, 'UniformOutput', false);
end
end
end
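% Example usage (a minimal sketch; assumes ActorNetwork, CriticNetwork and
% ReplayBuffer are available as referenced above, that states are passed as
% 'CB' dlarrays, and that the hyperparameter values shown are illustrative,
% not tuned):
%
%   agent = TD3Agent(state_dim, action_dim, max_action, ...
%       1e-3, 1e-3, 0.99, 0.005, 0.2, 0.5, 2);
%   action = agent.selectAction(dlarray(single(state), 'CB'));
%   % ... add (state, action, reward, next_state, done) to agent.replay_buffer
%   % using whatever storage method the ReplayBuffer class provides, then:
%   agent.train(256);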