SARSA (State-Action-Reward-State-Action)#

In this lesson, we’ll explore SARSA, a value-based reinforcement learning algorithm that is closely related to Q-learning. Unlike Q-learning, however, SARSA is an on-policy algorithm: its Q-value updates are based on the action the agent actually chooses in the next state, rather than on the highest-valued action available there.

The Q-Value Function in SARSA#

Similar to Q-learning, SARSA aims to learn the Q-value, \(Q(s,a)\), which is the expected cumulative reward starting from state \(s\), taking action \(a\), and following the current policy.
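
Written out, the Q-value under the current policy \(\pi\) is the expected discounted sum of future rewards (with discount factor \(\gamma\), defined below):

\[ Q^{\pi}(s,a) = \mathbb{E}_{\pi}\left[ \sum_{k=0}^{\infty} \gamma^{k} r_{t+k+1} \mid s_t = s,\ a_t = a \right] \]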

The key update equation for SARSA is:

\[ Q(s,a) \leftarrow Q(s,a) + \alpha \left[ r + \gamma Q(s',a') - Q(s,a) \right] \]

Here:

  • \(\alpha\) is the learning rate, controlling how much we update the Q-value,

  • \(r\) is the reward received after taking action \(a\) in state \(s\),

  • \(s'\) is the next state,

  • \(a'\) is the action taken in the next state \(s'\) according to the policy,

  • \(\gamma\) is the discount factor, balancing immediate vs. future rewards.

In simpler terms, SARSA updates the Q-value based on the actual action taken in the next state, following the current policy.
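
As a quick worked example with made-up numbers: suppose \(\alpha = 0.1\), \(\gamma = 0.9\), the current estimate is \(Q(s,a) = 2.0\), the agent receives \(r = 1\), and the action it actually picks in \(s'\) has \(Q(s',a') = 3.0\). The update is then:

\[ Q(s,a) \leftarrow 2.0 + 0.1 \left[ 1 + 0.9 \times 3.0 - 2.0 \right] = 2.0 + 0.1 \times 1.7 = 2.17 \]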

Temporal Difference Learning in SARSA#

Like Q-learning, SARSA relies on Temporal Difference (TD) learning, meaning it updates its estimates at every step without having to wait for the episode to finish. The difference is that SARSA does so on-policy. The TD target in SARSA is:

\[ \text{TD target} = r + \gamma Q(s',a') \]

The TD error, which drives the update, is:

\[ \text{TD error} = r + \gamma Q(s',a') - Q(s,a) \]

Again, unlike Q-learning, which uses the maximum Q-value over actions in the next state, SARSA uses the Q-value of the action actually taken by the agent in state \(s'\).
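
To make the difference concrete, here is a minimal sketch comparing the two targets for a tabular Q stored as a NumPy array indexed as Q[state, action] (the function and variable names are illustrative, not from the lesson's code):

import numpy as np

def q_learning_target(Q, reward, next_state, gamma):
    # Off-policy: bootstrap from the best action available in s'.
    return reward + gamma * np.max(Q[next_state])

def sarsa_target(Q, reward, next_state, next_action, gamma):
    # On-policy: bootstrap from the action a' actually taken in s'.
    return reward + gamma * Q[next_state, next_action]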

The SARSA Process#

The SARSA algorithm works similarly to Q-learning but updates the Q-values using the actions that the agent actually performs. The process follows these steps:

  1. Initialize Q-values: Start by initializing the Q-table (or function) for all state-action pairs arbitrarily (often to zeros).

  2. For each episode and each step:

    • Observe the current state \(s\),

    • Choose an action \(a\) using an exploration strategy such as epsilon-greedy: with probability \(\epsilon\), choose a random action (exploration); otherwise, choose the action with the highest Q-value (exploitation),

    • Execute action \(a\), observe the reward \(r\) and the next state \(s'\),

    • Choose the next action \(a'\) in \(s'\) according to the current policy (this is the key difference from Q-learning),

    • Update the Q-value for \((s,a)\) using the SARSA update rule:

    \[ Q(s,a) \leftarrow Q(s,a) + \alpha \left[ r + \gamma Q(s',a') - Q(s,a) \right] \]
    • Set the current state to \(s'\) and the current action to \(a'\), and repeat until the episode ends.

  3. Repeat for many episodes: Over time, the Q-values should converge, and the agent will learn the best actions to take under the current policy.
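
Before moving on to the neural-network version below, here is a minimal tabular sketch of these steps. The environment interface (env.reset() returning a state index, env.step(action) returning (next_state, reward, done)) is a simplifying assumption, not the lesson's actual environment:

import numpy as np

def sarsa_tabular(env, n_states, n_actions, episodes,
                  alpha=0.1, gamma=0.99, epsilon=0.1):
    # 1. Initialize Q-values for all state-action pairs (here: zeros).
    Q = np.zeros((n_states, n_actions))

    def epsilon_greedy(state):
        if np.random.random() < epsilon:
            return np.random.randint(n_actions)   # explore
        return int(np.argmax(Q[state]))           # exploit

    # 2. For each episode and each step.
    for _ in range(episodes):
        state = env.reset()                       # assumed API: returns a state index
        action = epsilon_greedy(state)
        done = False
        while not done:
            next_state, reward, done = env.step(action)   # assumed API
            next_action = epsilon_greedy(next_state)      # a' chosen by the same policy
            td_target = reward + gamma * Q[next_state, next_action] * (not done)
            Q[state, action] += alpha * (td_target - Q[state, action])
            state, action = next_state, next_action       # s <- s', a <- a'
    return Q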

Epsilon-Greedy Exploration in SARSA#

Like Q-learning, SARSA also uses the epsilon-greedy exploration strategy to balance exploration and exploitation. With probability \(\epsilon\), the agent chooses a random action, and with probability \(1-\epsilon\), it chooses the action that has the highest Q-value according to the current policy.
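
The agent code in the next section also calls a self.epsilon_decay() helper (not defined in the snippet) to reduce \(\epsilon\) over training, so exploration fades as the Q-values improve. A minimal sketch of such a decay schedule, with illustrative default values rather than the lesson's actual settings:

def decay_epsilon(epsilon, epsilon_min=0.05, decay_rate=0.995):
    # Shrink epsilon geometrically after each update, but keep a floor so the
    # agent never stops exploring entirely. Both defaults are illustrative.
    return max(epsilon_min, epsilon * decay_rate)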

Coding SARSA#

import torch
import numpy as np
from torch import nn

class SARSACar():

    def __init__(self):

        self.model = self.create_model() # 1. Initialize Q-values

    def act_epsilon_greedy(self, state):
        # Epsilon-greedy action selection: explore with probability epsilon,
        # otherwise pick the action with the highest predicted Q-value.
        state = torch.tensor(state, dtype=torch.float32).to(self.device)
        if np.random.random() < self.epsilon:
            return np.random.randint(0, self.output_size)
        else:
            return int(np.argmax(self.get_qs(state)))
        
    def action_train(self, state):

        action = self.act_epsilon_greedy(state)

        if action == 0:
            self.angle += 10  # Left
        elif action == 1:
            self.angle -= 10  # Right
        elif action == 2:
            if self.speed - 2 >= 6:
                self.speed -= 2  # Slow Down
        else:
            self.speed += 2  # Speed Up

        return action
    

    def train(self, state, action, reward, new_state, done):
        # Perform one SARSA update from the transition (s, a, r, s').

        state_tensor = (
            torch.tensor(np.array(state), dtype=torch.float32)
            .unsqueeze(0)
            .to(self.device)
        ) 
        new_state_tensor = (
            torch.tensor(np.array(new_state), dtype=torch.float32)
            .unsqueeze(0)
            .to(self.device)
        )

        # Get the Q-value for the current state-action pair
        current_q_values = self.model(state_tensor)  
        current_q_value = current_q_values.gather(
            1, torch.tensor([[action]], dtype=torch.long).to(self.device)
        ) 
        current_q_value = current_q_value.squeeze(1)  

        # If the episode is done, next_q_value should be 0
        if done:
            next_q_value = torch.tensor([0.0], dtype=torch.float32).to(self.device)
        else:
            # Get the next action using the epsilon-greedy policy for SARSA
            next_action = self.act_epsilon_greedy(new_state)
            next_q_values = self.model(new_state_tensor)  
            next_q_value = next_q_values.gather(
                1, torch.tensor([[next_action]], dtype=torch.long).to(self.device)
            )  
            next_q_value = next_q_value.squeeze(1) 

        # Compute the target Q-value using the SARSA update rule
        target_q_value = (
            torch.tensor([reward], dtype=torch.float32).to(self.device)
            + self.discount_factor * next_q_value
        )

        # Calculate the loss between the current Q-value and target Q-value
        loss = nn.MSELoss()(current_q_value, target_q_value)

        # Backpropagation
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Decay epsilon for exploration-exploitation tradeoff
        self.epsilon_decay()

        return loss.item()


class SARSARace():

    def training_race(self, car: SARSACar, episodes):

        for episode in range(1, episodes + 1): #2. For each episode 

            current_state = car.get_data() # Observe the current state
            done = False
            episode_reward = 0
            while not done:

                action = car.action_train(current_state) # Choose an action and execute it
                new_state, reward, done = self.step(car) # Observe the new state
                episode_reward += reward

                # Note: car.train samples the next action a' internally with the same
                # epsilon-greedy policy when it forms the SARSA target r + gamma * Q(s', a').
                loss = car.train(current_state, action, reward, new_state, done) # Update the Q-values

                current_state = new_state # Set the current state to s'
                

Actual training#