Twenty questions and decision trees | R bloggers

Most of us have probably played the game Twenty Questions. The answerer chooses something and the other players try to guess it by asking yes or no questions. The TV show ‘What’s My Line’ is an example of this, where the players try to guess the profession of a participant.

Contrary to my children’s attempts, the strategy should be to ask questions that roughly split the remaining possibilities in half each time.
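As a rough sanity check on the halving strategy, here is a small sketch in R. It assumes 47 candidates (the number of presidents used later in this post) and that every question splits the remaining candidates as evenly as possible. The helper name `questions_needed` is made up for illustration.

```r
# Hypothetical helper: number of yes/no questions needed if each question
# splits the remaining candidates as evenly as possible.
questions_needed <- function(n) {
  questions <- 0
  while (n > 1) {
    n <- ceiling(n / 2)   # worst case: we are left with the larger half
    questions <- questions + 1
  }
  questions
}

questions_needed(47)      # 6
ceiling(log2(47))         # 6, the same bound in closed form
```

Under these assumptions, six well-chosen questions are enough for 47 presidents, and twenty questions could in principle separate 2^20 (about a million) possibilities.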

This is very similar to a machine learning decision tree, albeit with an interesting distinction.

A decision tree cheats. The decision tree algorithm knows the answer (df$target = 1). At each node, the algorithm looks for the feature and split value that best separate df$target = 1 from df$target = 0, but it needs to know the right answer in order to ask the best questions. This is why, if the game is played several times with different US presidents, the algorithm can choose different features and split values.
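To make "finding the best feature and split" more concrete, here is a minimal sketch of the Gini calculation that a classification tree typically uses to score a candidate split. The toy data is invented for illustration (it is not the presidents file), and `gini` and `split_improvement` are hypothetical helper names, not rpart functions.

```r
# Gini impurity of a 0/1 target vector
gini <- function(y) {
  p <- mean(y == 1)
  2 * p * (1 - p)
}

# Reduction in weighted Gini impurity when rows are split by a logical vector
split_improvement <- function(y, goes_left) {
  n <- length(y)
  n_left <- sum(goes_left)
  n_right <- n - n_left
  gini(y) -
    (n_left / n) * gini(y[goes_left]) -
    (n_right / n) * gini(y[!goes_left])
}

# Made-up toy data: 8 candidates, exactly one of them is the target
toy <- data.frame(
  target = c(0, 0, 0, 0, 0, 0, 0, 1),
  party  = c("D", "D", "D", "D", "R", "R", "R", "R"),
  lawyer = c(TRUE, TRUE, TRUE, FALSE, TRUE, TRUE, TRUE, FALSE)
)

split_improvement(toy$target, toy$party == "R")   # 0.03125 (even 4/4 split)
split_improvement(toy$target, !toy$lawyer)        # 0.09375 (uneven 2/6 split)
```

In this toy example the uneven "lawyer" question scores higher than the even "party" question, because it puts the known target into a smaller group. A human player, who does not know the answer, would usually prefer the even split.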

Still, I thought it would be fun to build a decision tree model using the US presidents. I found some data on presidents. I decided that some variables had too many values (high cardinality; for example, there were many political party names in the 19th century), so I grouped some values together to reduce the number of unique values.

I initially started with a random integer between 1 and 47 to select a president, which selected President Hoover, but I found that a different president would create a tree that would be closer to the questions I would have asked if I were a player. So I picked President Reagan to get a more interesting tree.

(I considered selecting President Garfield in order to ask the question, “Did the president come up with a unique proof of the Pythagorean Theorem?”, but I decided that was a bit quirky even for me.)

Here is the resulting tree for President Reagan:

Here is the resulting plot of variable importance. Note that the variables are not in the same order as the tree splits. My understanding is that the importance of a variable is based on the sum of the improvements at all nodes where the variable was used as a splitter, with some additional credit when it acts as a surrogate splitter.
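To see why the importance ordering can differ from the split ordering, here is a small sketch that compares the two on a fitted tree. It uses the `kyphosis` data set that ships with rpart as a stand-in, since the presidents file is not reproduced here; the object names `fit` and `splitters` are just for illustration.

```r
library(rpart)

# kyphosis ships with rpart; it stands in for the presidents data here
fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis, method = "class")

# Variables actually used as primary splitters, in the order they appear in the tree
is_split <- fit$frame$var != "<leaf>"
splitters <- unique(as.character(fit$frame$var[is_split]))
print(splitters)

# Importance is accumulated over the whole tree, so its ordering
# need not match the order of the splits
print(fit$variable.importance)
```

Comparing the two printed results on your own tree is a quick way to check which variables earn their importance from the top splits and which earn it further down.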

Here is my R code:


library(dplyr)
library(rpart)
library(rpart.plot)
library(ggplot2)
df <- read.csv("prez.csv", header=TRUE)   
# data file available at github: prez.csv
set.seed(123)
# r <- sample(1:nrow(df),1)
r <- 40   # deliberate choice to get longer tree
answer <- df$LastName[r]
print(paste("The target president is:", answer))
df$target <- rep(0, nrow(df))
df$target[r] <- 1

# Feature engineering:
df <- df %>%
  # A. Categorical Reduction
  mutate(
    Party = case_when(
        Party %in% c("Democratic") ~ "Democratic", 
        Party %in% c("Republican") ~ "Republican", 
        TRUE ~ "Other"),
    Occupation = case_when(
        Occupation %in% c("Businessman", "Lawyer") ~ Occupation, 
        TRUE ~ "Other"),
    State = case_when(
        State %in% c("New York") ~ "NY", 
        State %in% c("Ohio") ~ "OH", 
        State %in% c("Virginia") ~ "VA", 
        State %in% c("Massachusetts") ~ "MA", 
        State %in% c("Texas") ~ "TX", TRUE ~ "Other"),
    Religion = case_when(
        Religion %in% c("Episcopalian", "Presbyterian", "Unitarian", "Baptist", "Methodist") ~ "Main_Prot", 
        TRUE ~ "Other"),
    
    # B. Year/Century Binning using cut()
    DOB = cut(DOB, breaks = c(-Inf, 1800, 1900, 2000, Inf), 
        labels = c("18th century", "19th century", "20th century", "21st century"), right = FALSE),
    DOD = cut(DOD, breaks = c(-Inf, 1800, 1900, 2000, Inf), 
        labels = c("18th century", "19th century", "20th century", "21st century"), right = FALSE),
    Began = cut(Began, breaks = c(-Inf, 1800, 1900, 2000, Inf), 
        labels = c("18th century", "19th century", "20th century", "21st century"), right = FALSE),
    Ended = cut(Ended, breaks = c(-Inf, 1800, 1900, 2000, Inf), 
        labels = c("18th century", "19th century", "20th century", "21st century"), right = FALSE)
  ) %>%

  # C. Convert all new/existing binary/categorical variables to factor
  mutate(across(c(Party, State, Occupation, Religion, Assassinated, Military, Terms_GT_1, Pres_During_War, Was_Veep, DOB, DOD, Began, Ended), as.factor))


# selected variables 
formula <- as.formula(target ~ Began + State + Party + Occupation + Pres_During_War)
# Using aggressive control settings to force a maximal, unpruned tree
prez_tree <- rpart(formula, data = df, method = "class",
                   control = rpart.control(cp = 0.001, minsplit = 2, minbucket = 1))
rpart.plot(prez_tree, type = 4, extra = 101, main = "President Twenty Questions")

# check Reagan
df %>% filter(Began == "20th century" & 
              !State %in% c("MA", "NY", "OH", "TX") &
              Party == "Republican" &
              !Occupation %in% c( "Businessman", "Lawyer"))

variable_importance <- prez_tree$variable.importance
importance_df <- data.frame(
  Variable = names(variable_importance),
  Importance = variable_importance
)

importance_df <- importance_df[order(importance_df$Importance, decreasing = TRUE), ]

common_theme <- theme(
        legend.position="none",
        plot.title = element_text(size=15, face="bold"),
        plot.subtitle = element_text(size=12.5, face="bold"),
        axis.title = element_text(size=15, face="bold"),
        axis.text = element_text(size=15, face="bold"),
        legend.title = element_text(size=15, face="bold"),
        legend.text = element_text(size=15, face="bold"))


ggplot(importance_df, aes(x = factor(Variable, levels = rev(Variable)), y = Importance)) +
  geom_col(aes(fill = Variable)) + 
  coord_flip() +
  labs(title = "20 Questions Variable Importance",
       x = "Variable",
       y = "Importance") +
  common_theme

# loop through all presidents to see different first split vars
library(purrr)

### 1. Define the Analysis Function ###
# The function is modified to return a data frame row for clarity
get_first_split_row <- function(df, r) {
  # Temporarily set the target for the current president
  df$target <- 0
  df$target[r] <- 1
  tree <- rpart(formula, data = df, method = "class",
                control = rpart.control(cp = 0.001, minsplit = 2, minbucket = 1))
  frame <- tree$frame
  
  # Determine the result
  if (nrow(frame) > 1) {
    first_split_var <- as.character(frame$var[1])
  } else {
    first_split_var <- "No split"
  }
  
  # Return a single-row data frame
  return(data.frame(
    President = df$LastName[r],
    First_Split_Variable = first_split_var
  ))
}

### 2. Run the Analysis and Combine Results ###
# Create a list of row indices to iterate over
indices_to_run <- 1:nrow(df)

# Use map_dfr to apply the function to every index and combine the results 
# into a single data frame (dfr = data frame row bind)
first_split_results_df <- map_dfr(indices_to_run, ~ get_first_split_row(df, .x))

### 3. Display the Table and Original Analysis ###
# Display the resulting table:
print(first_split_results_df)

print(table(first_split_results_df$First_Split_Variable))
  
  

End

