SARSA (State-Action-Reward-State-Action)#
In this lesson, we’ll explore SARSA, a value-based reinforcement learning algorithm that is closely related to Q-learning. Unlike Q-learning, however, SARSA is an on-policy algorithm: its Q-value updates are based on the action the agent actually chooses in the next state, rather than on the maximum-valued action in that state.
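To make the distinction concrete, here is a minimal sketch comparing the two update targets, assuming a tabular Q array indexed by state and action; all names and numbers below are illustrative and not part of the lesson's code:

```python
import numpy as np

# Hypothetical tabular setup: Q[state, action], plus one observed transition
Q = np.zeros((5, 3))                         # 5 states, 3 actions (illustrative sizes)
gamma = 0.99                                 # discount factor
s, a, r, s_next, a_next = 0, 1, 1.0, 2, 0    # one transition (s, a, r, s', a')

# Q-learning (off-policy): bootstrap from the best action in s'
q_learning_target = r + gamma * np.max(Q[s_next])

# SARSA (on-policy): bootstrap from the action a' the policy actually chose in s'
sarsa_target = r + gamma * Q[s_next, a_next]
```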
The Q-Value Function in SARSA#
Similar to Q-learning, SARSA aims to learn the Q-value \(Q(s,a)\): the expected cumulative reward from starting in state \(s\), taking action \(a\), and following the current policy thereafter.
The key update equation for SARSA is:
\[ Q(s,a)←Q(s,a)+α[r+γQ(s′,a′)−Q(s,a)] \]
Here:
\(α\) is the learning rate, controlling how much we update the Q-value,
\(r\) is the reward received after taking action \(a\) in state \(s\),
\(s′\) is the next state,
\(a′\) is the action taken in the next state \(s′\) according to the policy,
\(γ\) is the discount factor, balancing immediate vs. future rewards.
In simpler terms, SARSA updates the Q-value based on the actual action taken in the next state, following the current policy.
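As a minimal sketch, this update rule can be written as a small helper for a tabular Q array; the array layout and argument names here are illustrative, not taken from the lesson's code:

```python
import numpy as np

def sarsa_update(Q, s, a, r, s_next, a_next, alpha=0.1, gamma=0.99):
    """Apply one SARSA update to a tabular Q array indexed as Q[state, action]."""
    td_target = r + gamma * Q[s_next, a_next]   # bootstrap from the action actually chosen
    Q[s, a] += alpha * (td_target - Q[s, a])    # move Q(s,a) toward the TD target
    return Q
```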
Temporal Difference Learning in SARSA#
Like Q-learning, SARSA relies on Temporal Difference (TD) learning, meaning it updates its estimates at every step rather than waiting for the episode to finish, but with the difference that SARSA is on-policy. The TD target in SARSA is:
\[ r+γQ(s′,a′) \]
The TD error, which drives the update, is:
\[ r+γQ(s′,a′)−Q(s,a) \]
Again, unlike Q-learning, which uses the maximum Q-value over actions in the next state, SARSA uses the Q-value of the action actually taken by the agent in state \(s′\).
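For instance, with illustrative numbers (a reward of 1, \(γ = 0.9\), a next-state estimate \(Q(s′,a′) = 2.0\), and a current estimate \(Q(s,a) = 2.5\)), the TD target and TD error work out as follows:

```python
gamma = 0.9
r = 1.0
q_next = 2.0       # Q(s', a') for the action the policy actually picked
q_current = 2.5    # current estimate Q(s, a)

td_target = r + gamma * q_next      # 1.0 + 0.9 * 2.0 = 2.8
td_error = td_target - q_current    # 2.8 - 2.5 = 0.3 (positive: raise Q(s,a))
```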
The SARSA Process#
The SARSA algorithm works similarly to Q-learning but updates the Q-values using the actions that the agent actually performs. The process follows these steps:
Initialize Q-values: Start by initializing the Q-table (or function) for all state-action pairs arbitrarily (often to zeros).
For each episode and each step:
Observe the current state \(s\),
Choose an action \(a\) using an exploration strategy like epsilon-greedy: with probability \(ϵ\), choose a random action (exploration); otherwise, choose the action with the highest Q-value (exploitation),
Execute action \(a\), observe the reward \(r\), and the next state \(s′\),
Choose the next action \(a′\) based on the current policy (this is the key difference from Q-learning),
Update the Q-value for \((s,a)\) using the SARSA update rule:
\[ Q(s,a)←Q(s,a)+α[r+γQ(s′,a′)−Q(s,a)] \]
Set the current state to \(s′\) and the current action to \(a′\), and repeat until the episode ends.
Repeat for many episodes: Over time, the Q-values should converge, and the agent will learn the best actions to take under the current policy.
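Putting the steps above together, here is a minimal tabular sketch of the SARSA loop. It assumes a hypothetical environment object with `reset()` and `step(action)` methods and small discrete state and action spaces; none of these names come from the lesson's code.

```python
import numpy as np

def train_sarsa(env, n_states, n_actions, episodes=500,
                alpha=0.1, gamma=0.99, epsilon=0.1):
    Q = np.zeros((n_states, n_actions))              # 1. Initialize Q-values

    def epsilon_greedy(s):
        if np.random.random() < epsilon:             # explore with probability epsilon
            return np.random.randint(n_actions)
        return int(np.argmax(Q[s]))                  # otherwise exploit

    for _ in range(episodes):                        # 2. For each episode
        s = env.reset()                              # observe the initial state
        a = epsilon_greedy(s)                        # choose the first action
        done = False
        while not done:
            s_next, r, done = env.step(a)            # execute a, observe r and s'
            a_next = epsilon_greedy(s_next)          # choose a' under the current policy
            target = r + (0.0 if done else gamma * Q[s_next, a_next])
            Q[s, a] += alpha * (target - Q[s, a])    # SARSA update
            s, a = s_next, a_next                    # set s <- s', a <- a'
    return Q
```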
Epsilon-Greedy Exploration in SARSA#
Like Q-learning, SARSA also uses the epsilon-greedy exploration strategy to balance exploration and exploitation. With probability \(ϵ\), the agent chooses a random action, and with probability \(1−ϵ\), it chooses the action that has the highest Q-value according to the current policy.
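The agent below also decays \(ϵ\) over time (the `self.epsilon_decay()` call in its training step). One common scheme, shown here purely as an illustrative sketch rather than the lesson's own implementation, is multiplicative decay down to a minimum floor:

```python
class EpsilonSchedule:
    """Illustrative multiplicative epsilon decay with a minimum floor."""
    def __init__(self, epsilon=1.0, decay=0.995, epsilon_min=0.05):
        self.epsilon = epsilon
        self.decay = decay
        self.epsilon_min = epsilon_min

    def step(self):
        # Shrink epsilon after each update, but never below the floor
        self.epsilon = max(self.epsilon_min, self.epsilon * self.decay)
        return self.epsilon
```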
Coding SARSA#
import torch
import numpy as np
from torch import nn

class SARSACar():
    def __init__(self):
        # Hyperparameters and helpers used below (epsilon, output_size, device,
        # discount_factor, optimizer, get_qs, epsilon_decay, ...) are assumed to
        # be set up elsewhere in the full class.
        self.model = self.create_model()  # 1. Initialize Q-values

    def act_epsilon_greedy(self, state):
        state = torch.tensor(state, dtype=torch.float32).to(self.device)
        if np.random.random() < self.epsilon:
            return np.random.randint(0, self.output_size)  # explore: random action
        else:
            return int(np.argmax(self.get_qs(state)))      # exploit: best-valued action

    def action_train(self, state):
        action = self.act_epsilon_greedy(state)
        if action == 0:
            self.angle += 10  # Left
        elif action == 1:
            self.angle -= 10  # Right
        elif action == 2:
            if self.speed - 2 >= 6:
                self.speed -= 2  # Slow Down
        else:
            self.speed += 2  # Speed Up
        return action
    def train(self, state, action, reward, new_state, done):
        state_tensor = (
            torch.tensor(np.array(state), dtype=torch.float32)
            .unsqueeze(0)
            .to(self.device)
        )
        new_state_tensor = (
            torch.tensor(np.array(new_state), dtype=torch.float32)
            .unsqueeze(0)
            .to(self.device)
        )

        # Get the Q-value for the current state-action pair
        current_q_values = self.model(state_tensor)
        current_q_value = current_q_values.gather(
            1, torch.tensor([[action]], dtype=torch.long).to(self.device)
        )
        current_q_value = current_q_value.squeeze(1)

        # If the episode is done, next_q_value should be 0
        if done:
            next_q_value = torch.tensor([0.0], dtype=torch.float32).to(self.device)
        else:
            # Get the next action using the epsilon-greedy policy for SARSA
            next_action = self.act_epsilon_greedy(new_state)
            next_q_values = self.model(new_state_tensor)
            next_q_value = next_q_values.gather(
                1, torch.tensor([[next_action]], dtype=torch.long).to(self.device)
            )
            # Detach so no gradient flows through the bootstrapped target
            next_q_value = next_q_value.squeeze(1).detach()

        # Compute the target Q-value using the SARSA update rule
        target_q_value = (
            torch.tensor([reward], dtype=torch.float32).to(self.device)
            + self.discount_factor * next_q_value
        )

        # Calculate the loss between the current Q-value and target Q-value
        loss = nn.MSELoss()(current_q_value, target_q_value)

        # Backpropagation
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Decay epsilon for the exploration-exploitation tradeoff
        self.epsilon_decay()

        return loss.item()
class SARSARace():
    def training_race(self, car: SARSACar, episodes):
        for episode in range(1, episodes + 1):  # 2. For each episode
            current_state = car.get_data()      # Observe the current state
            done = False
            episode_reward = 0
            while not done:
                action = car.action_train(current_state)  # Choose an action a and execute it
                new_state, reward, done = self.step(car)  # Observe the reward and the new state
                episode_reward += reward
                # Update the Q-values; train() chooses the next action a'
                # epsilon-greedily and applies the SARSA update rule
                loss = car.train(current_state, action, reward, new_state, done)
                current_state = new_state  # Set the current state to s'
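Because the agent above approximates \(Q(s,a)\) with a neural network, the SARSA rule is expressed as a regression target: the predicted Q-value for the executed action is pushed toward \(r+γQ(s′,a′)\), where \(a′\) comes from the same epsilon-greedy policy the agent acts with. Assuming the remaining pieces of the classes (model creation, `get_data`, `step`, the optimizer, and epsilon decay) are in place, training could be kicked off with a sketch like the following; the constructor calls and episode count are assumptions, not part of the lesson's code:

```python
# Hypothetical driver code; assumes SARSACar and SARSARace are fully set up elsewhere
car = SARSACar()
race = SARSARace()
race.training_race(car, episodes=1000)  # run SARSA training for 1000 episodes
```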