library(knitr)
library(rmdformats)

## Global options
options(max.print="75")
opts_chunk$set(echo=FALSE,
                 cache=TRUE,
               prompt=FALSE,
               tidy=TRUE,
               comment=NA,
               message=FALSE,
               warning=FALSE)
opts_knit$set(width=75)

Arboles de decisión con R - Clasificación

Basado en textos de: Juan Bosco Mendoza Vega 23 de abril de 2018

Resumen

En este artículo revisaremos lo esencial para implementar árboles de decisión en R, en particular el caso de los árboles de clasificación, usando el paquete rpart. Utilizaremos un conjunto de datos usado frecuentemente para probar métodos de aprendizaje automático en nuestro ejemplo y durante el proceso daremos también un vistazo a algunos problemas comunes al procesar información en R.

Una introducción informal a los árboles de decisión

Los árboles de decisión son un método usado en distintas disciplinas como modelo de predicción. Estos son similares a diagramas de flujo, en los que llegamos a puntos en los que se toman decisiones de acuerdo a una regla.

En el campo del aprendizaje automático, hay distintas maneras de obtener árboles de decisión, la que usaremos en esta ocasión es conocida como CART: Classification And Regression Trees. Esta es una técnica de aprendizaje supervisado. Tenemos una variable objetivo (dependiente) y nuestra meta es obtener una función que nos permita predecir, a partir de variables predictoras (independientes), el valor de la variable objetivo para casos desconocidos.

Como el nombre indica, CART es una técnica con la que se pueden obtener árboles de clasificación y de regresión. Usamos clasificación cuando nuestra variable objetivo es discreta, mientras que usamos regresión cuando es continua. Nosotros tendremos una variable discreta, así que haremos clasificación.

La implementación particular de CART que usaremos es conocida como Recursive Partitioning and Regression Trees o RPART. De allí el nombre del paquete que utilizaremos en nuestro ejemplo.

De manera general, lo que hace este algoritmo es encontrar la variable independiente que mejor separa nuestros datos en grupos, que corresponden con las categorías de la variable objetivo. Esta mejor separación es expresada con una regla. A cada regla corresponde un nodo.

Por ejemplo, supongamos que nuestra variable objetivo tiene dos niveles, deudor y no deudor. Encontramos que la variable que mejor separa nuestros datos es ingreso mensual, y la regla resultante es que ingreso \(mensual > X pesos\). Esto quiere decir que los datos para los que esta regla es verdadera, tienen más probabilidad de pertenecer a un grupo, que al otro. En este ejemplo, digamos que si la regla es verdadera, un caso tiene más probabilidad de formar parte del grupo no deudor.

Una vez hecho esto, los datos son separados (particionados) en grupos a partir de la regla obtenida. Después, para cada uno de los grupos resultantes, se repite el mismo proceso. Se busca la variable que mejor separa los datos en grupos, se obtiene una regla, y se separan los datos. Hacemos esto de manera recursiva hasta que nos es imposible obtener una mejor separación. Cuando esto ocurre, el algoritmo se detiene. Cuando un grupo no puede ser partido mejor, se le llama nodo terminal u hoja.

Una característica muy importante en este algoritmo es que una vez que alguna variable ha sido elegida para separar los datos, ya no es usada de nuevo en los grupos que ha creado. Se buscan variables distintas que mejoren la separación de los datos.

Además, supongamos después de una partición que hemos creado dos grupos, A y B. Es posible que para el grupo A, la variable que mejor separa estos datos sea diferente a la que mejor separa los datos en el grupo B. Una vez que los grupos se han separado, al algoritmo “no ve” lo que ocurre entre grupos, estos son independientes entre sí y las reglas que aplican para ellos no afectan en nada a los demás.

El resultado de todo el proceso anterior es una serie de bifurcaciones que tiene la apariencia de un árbol que va creciendo ramas, de allí el nombre del procedimiento (aunque a mí en realidad me parece más parecido a la raíz del árbol que a las ramas).

Las principales ventajas de este método son su interpretabilidad, pues nos da un conjunto de reglas a partir de las cuales se pueden tomar decisiones. Este es un algoritmo que no es demandante en poder de cómputo comparado con procedimientos más sofisticados y, a pesar de ello, que tiende a dar buenos resultados de predicción para muchos tipos de datos.

Sus principales desventajas son que este en tipo de clasificación “débil”, pues sus resultados pueden variar mucho dependiendo de la muestra de datos usados para entrenar un modelo. Además es fácil sobre ajustar los modelos, esto es, hacerlos excelentes para clasificar datos que conocemos, pero deficientes para datos conocidos.

Para saber más sobre este algoritmo, en particular que quiere decir eso de mejor separación, puedes leer el siguiente documento, que también llamar con vignette(topic = “longintro”, package = “rpart”):

https://cran.r-project.org/web/packages/rpart/vignettes/longintro.pdf

Paquetes necesarios

Usaremos los siguientes paquetes.

  • tidyverse: para llamar a la familia de paquetes tidyverse, que nos ayudaran al procesamiento de nuestros datos.
  • rpart: el paquete con la implementación de árboles de clasificación que utilizaremos.
  • rpart.plot: para graficar los resultados de rpart.
  • caret: un paquete con utilidades para clasificación y regresión. Lo usaremos por su función para crear matrices de confusión

Importando nuestros datos

Descargaremos el conjunto de datos de vino, disponible en el Machine Learning Repository.

https://archive.ics.uci.edu/ml/datasets/Wine Necesitamos descargar dos archivos. El primero contiene los datos que usaremos, y el segundo contiene su descripción (metadatos), la cual nos será de gran utilidad más adelante.

# Datos
download.file("https://archive.ics.uci.edu/ml/machine-learning-databases/wine/wine.data", "wine.data")

# Información
download.file("https://archive.ics.uci.edu/ml/machine-learning-databases/wine/wine.names", "wine.names")

Como habrás notado, nuestros datos tienen una extensión de archivo no convencional: .data. En R no existe una función específica para leer archivos con esta extensión, similar a red.csv() o read.dat(), las cuales nos facilitan tarea de importar archivos de formatos específicos. Lo mismmo pasa con el archivo con su descripción, que tiene la extensión .name.

Necesitamos explorar estos archivos para saber cómo podemos leerlos. Para estos casos, usamos la función readLines(), que lee archivos, línea por línea, independientemente de su extensión o formato. con el argumento n = 10 indicamos que sólo deseamos leer las primeras diez líneas de cada archivo.

Empezamos con los datos.

readLines("wine.data", n = 10)
 [1] "1,14.23,1.71,2.43,15.6,127,2.8,3.06,.28,2.29,5.64,1.04,3.92,1065"
 [2] "1,13.2,1.78,2.14,11.2,100,2.65,2.76,.26,1.28,4.38,1.05,3.4,1050" 
 [3] "1,13.16,2.36,2.67,18.6,101,2.8,3.24,.3,2.81,5.68,1.03,3.17,1185" 
 [4] "1,14.37,1.95,2.5,16.8,113,3.85,3.49,.24,2.18,7.8,.86,3.45,1480"  
 [5] "1,13.24,2.59,2.87,21,118,2.8,2.69,.39,1.82,4.32,1.04,2.93,735"   
 [6] "1,14.2,1.76,2.45,15.2,112,3.27,3.39,.34,1.97,6.75,1.05,2.85,1450"
 [7] "1,14.39,1.87,2.45,14.6,96,2.5,2.52,.3,1.98,5.25,1.02,3.58,1290"  
 [8] "1,14.06,2.15,2.61,17.6,121,2.6,2.51,.31,1.25,5.05,1.06,3.58,1295"
 [9] "1,14.83,1.64,2.17,14,97,2.8,2.98,.29,1.98,5.2,1.08,2.85,1045"    
[10] "1,13.86,1.35,2.27,16,98,2.98,3.15,.22,1.85,7.22,1.01,3.55,1045"  

El archivo de datos parece ser una tabla de datos rectangular, con columnas separadas por comas. Entonces leer este archivo es fácil. El único inconveniente que tenemos es que nos faltan los nombres de cada columna.

Podemos usar read_table() para leer este archivo. Esta función está diseñada para leer tablas de datos, es decir, con estructura rectangular (renglones y columnas).

Para asegurarnos que los datos serán leídos de manera correcta, especificamos que el separador de las columnas es una coma (sep = “,”) y que no tenemos nombres de columna en nuestro archivo (header = FALSE). Asignamos el resultado al objeto vino.

vino <- read.table("wine.data", sep = ",", header = FALSE)

Veamos los datos.

vino
  V1    V2   V3   V4   V5  V6   V7   V8   V9  V10  V11  V12  V13  V14
1  1 14.23 1.71 2.43 15.6 127 2.80 3.06 0.28 2.29 5.64 1.04 3.92 1065
2  1 13.20 1.78 2.14 11.2 100 2.65 2.76 0.26 1.28 4.38 1.05 3.40 1050
3  1 13.16 2.36 2.67 18.6 101 2.80 3.24 0.30 2.81 5.68 1.03 3.17 1185
4  1 14.37 1.95 2.50 16.8 113 3.85 3.49 0.24 2.18 7.80 0.86 3.45 1480
5  1 13.24 2.59 2.87 21.0 118 2.80 2.69 0.39 1.82 4.32 1.04 2.93  735
 [ reached 'max' / getOption("max.print") -- omitted 173 rows ]

Tenemos 178 renglones y 14 columnas. Aunque aún no sabemos que contienen.

Veamos si el archivo wine.names tiene respuestas.

readLines("wine.names", n = 10)
 [1] "1. Title of Database: Wine recognition data"                                    
 [2] "\tUpdated Sept 21, 1998 by C.Blake : Added attribute information"               
 [3] ""                                                                               
 [4] "2. Sources:"                                                                    
 [5] "   (a) Forina, M. et al, PARVUS - An Extendible Package for Data"               
 [6] "       Exploration, Classification and Correlation. Institute of Pharmaceutical"
 [7] "       and Food Analysis and Technologies, Via Brigata Salerno, "               
 [8] "       16147 Genoa, Italy."                                                     
 [9] ""                                                                               
[10] "   (b) Stefan Aeberhard, email: stefan@coral.cs.jcu.edu.au"                     

Parece ser un archivo de texto común y corriente, pero con una extensión inusual. Podemos crear una copia de este archivo con la extensión a .txt con file.copy() para leerlo fácilmente en bloc de notas o cualquier aplicación similar. Después, usamos file.show() para darle una lectura.

file.copy(from = "wine.names", to = "wine_names.txt")
[1] FALSE
file.show("wine_names.txt")

A partir de lo que este documento explica, descubrimos que nuestros datos corresponden a trece características químicas de tres tipos de vinos. Esto quiere decir que una de las columnas de nuestros datos indica el tipo de vino y las otras trece son sus características.

Aunque es probable que la primera columna de nuestros datos sea la variable con el tipo de vino, usamos summary() para asegurarnos

summary(vino)
       V1              V2              V3              V4       
 Min.   :1.000   Min.   :11.03   Min.   :0.740   Min.   :1.360  
 1st Qu.:1.000   1st Qu.:12.36   1st Qu.:1.603   1st Qu.:2.210  
 Median :2.000   Median :13.05   Median :1.865   Median :2.360  
 Mean   :1.938   Mean   :13.00   Mean   :2.336   Mean   :2.367  
 3rd Qu.:3.000   3rd Qu.:13.68   3rd Qu.:3.083   3rd Qu.:2.558  
       V5              V6               V7              V8       
 Min.   :10.60   Min.   : 70.00   Min.   :0.980   Min.   :0.340  
 1st Qu.:17.20   1st Qu.: 88.00   1st Qu.:1.742   1st Qu.:1.205  
 Median :19.50   Median : 98.00   Median :2.355   Median :2.135  
 Mean   :19.49   Mean   : 99.74   Mean   :2.295   Mean   :2.029  
 3rd Qu.:21.50   3rd Qu.:107.00   3rd Qu.:2.800   3rd Qu.:2.875  
       V9              V10             V11              V12        
 Min.   :0.1300   Min.   :0.410   Min.   : 1.280   Min.   :0.4800  
 1st Qu.:0.2700   1st Qu.:1.250   1st Qu.: 3.220   1st Qu.:0.7825  
 Median :0.3400   Median :1.555   Median : 4.690   Median :0.9650  
 Mean   :0.3619   Mean   :1.591   Mean   : 5.058   Mean   :0.9574  
 3rd Qu.:0.4375   3rd Qu.:1.950   3rd Qu.: 6.200   3rd Qu.:1.1200  
      V13             V14        
 Min.   :1.270   Min.   : 278.0  
 1st Qu.:1.938   1st Qu.: 500.5  
 Median :2.780   Median : 673.5  
 Mean   :2.612   Mean   : 746.9  
 3rd Qu.:3.170   3rd Qu.: 985.0  
 [ reached getOption("max.print") -- omitted 1 row ]

Como la V1 es la única con un valor mínimo de 1 y máximo de 3, es seguro que corresponde al tipo de vino, así que podemos renombrarla para facilitar el análisis.

El resto de nombres de columna los podemos obtener del archivo con la información de los datos, haciendo algo de manipulación con expresiones regulares (regex), a través de gsub().

nombres <- 
  readLines("wine_names.txt")[58:70] %>% 
  gsub("[[:cntrl:]].*\\)", "", .) %>% 
  trimws() %>% 
  tolower() %>% 
  gsub(" |/", "_", .) %>% 
  # Agregamos el nombre "tipo", para nuestra primera columna con los tipos de vino

    c("tipo", .)

Ahora podemos cambiar los nombres de nuestros datos.

names(vino) <- nombres 

Por último, cambiamos el tipo de dato de la columna tipo a factor usando la función mutate_at() de dplyr, para poder hacer clasificaciones. De otro modo, como esta columna tiene valores numéricos, podemos tener conflictos más adelante.

vino <- vino %>% 
  mutate_at("tipo", factor)

Ahora sí, empecemos a crear árboles de clasificación.

Creando un sets de entrenamiento y prueba

Necesitamos un set de entrenamiento para generar un modelo predictivo, y un set de prueba, para comprobar la eficacia de este modelo para hacer predicciones correctas.

Usamos la función sample_frac() de dplyr para obtener un subconjunto de nuestros datos, que consiste en 70% del total de ellos. Usamos también set.seed() para que este ejemplo sea reproducible.

set.seed(1649)
vino_entrenamiento <- sample_frac(vino, .7)

Con setdiff() de dplyr, obtenemos el subconjunto de datos complementario al de entrenamiento para nuestro set de prueba, esto es, el 30% restante.

vino_prueba <- setdiff(vino, vino_entrenamiento)

Entrenando nuestro modelo

Usamos la función rpart de rpart para entrenar nuestro modelo. Esta función nos pide una formula para especificar la variable objetivo de la clasificación. La formula que usaremos es tipo ~ ., la cual expresa que intentaremos clasificar tipo usando a todas las demás variables como predictoras.

En este primer intento no ajustaremos ningún otro parámetro.

arbol_1 <- rpart(formula = tipo ~ ., data = vino_entrenamiento)

Es hora de ver cómo nos ha ido con nuestro modelo

Evaluando nuestro modelo

Del entrenamiento de nuestro modelo obtenemos el siguiente resultado.

arbol_1
n= 125 

node), split, n, loss, yval, (yprob)
      * denotes terminal node

1) root 125 75 2 (0.35200000 0.40000000 0.24800000)  
  2) proline>=755 48  6 1 (0.87500000 0.04166667 0.08333333)  
    4) flavanoids>=2.35 41  1 1 (0.97560976 0.02439024 0.00000000) *
    5) flavanoids< 2.35 7  3 3 (0.28571429 0.14285714 0.57142857) *
  3) proline< 755 77 29 2 (0.02597403 0.62337662 0.35064935)  
    6) flavanoids>=1.265 51  4 2 (0.03921569 0.92156863 0.03921569) *
    7) flavanoids< 1.265 26  1 3 (0.00000000 0.03846154 0.96153846) *
rpart.plot(arbol_1)

En estos gráficos, cada uno de los rectángulos representa un nodo de nuestro árbol, con su regla de clasificación.

Cada nodo está coloreado de acuerdo a la categoría mayoritaria entre los datos que agrupa. Esta es la categoría que ha predicho el modelo para ese grupo.

Dentro del rectángulo de cada nodo se nos muestra qué proporción de casos pertenecen a cada categoría y la proporción del total de datos que han sido agrupados allí. Por ejemplo, el rectángulo en el extremo inferior izquierdo de la gráfica tiene 94% de casos en el tipo 1, y 4% en los tipos 2 y 3, que representan 39% de todos los datos.

Estas proporciones nos dan una idea de la precisión de nuestro modelo al hacer predicciones. De este modo, las reglas que conducen al rectángulo que acabamos de mencionar nos dan un 92% de clasificaciones correctas. En contraste, el tercer rectángulo, de izquierda a derecha, de color gris, tuvo sólo 62% de clasificaciones correctas.

Además, podemos sentirnos contentos de que dos de las hojas de nuestro árbol de clasificación han logrado un 100% de clasificaciones correctas, para los vinos de tipo 2 y 3.

Pero, por supuesto, necesitamos ser más sistemáticos para indagar qué tan bien hace predicciones nuestro modelo.

Usamos la función precict() con nuestro set de prueba para generar un vector con los valores predichos por el modelo que hemos entrenado, especificamos el parámetro type = “class”.

argumento type = "class para

prediccion_1 <- predict(arbol_1, newdata = vino_prueba, type = "class")

Cruzamos la predicción con los datos reales de nuestro set de prueba para generar una matriz de confusión, usando confusionMatrix() de caret.

confusionMatrix(prediccion_1, vino_prueba[["tipo"]])
Confusion Matrix and Statistics

          Reference
Prediction  1  2  3
         1 15  0  0
         2  0 15  3
         3  0  6 14

Overall Statistics
                                         
               Accuracy : 0.8302         
                 95% CI : (0.702, 0.9193)
    No Information Rate : 0.3962         
    P-Value [Acc > NIR] : 1.106e-10      
                                         
                  Kappa : 0.7444         
                                         
 Mcnemar's Test P-Value : NA             

Statistics by Class:

                     Class: 1 Class: 2 Class: 3
Sensitivity             1.000   0.7143   0.8235
Specificity             1.000   0.9062   0.8333
Pos Pred Value          1.000   0.8333   0.7000
Neg Pred Value          1.000   0.8286   0.9091
Prevalence              0.283   0.3962   0.3208
Detection Rate          0.283   0.2830   0.2642
Detection Prevalence    0.283   0.3396   0.3774
Balanced Accuracy       1.000   0.8103   0.8284

Nada mal. Tenemos una precisión (accuracy), Kappa y otros estadísticos con buenos valores.

Sin embargo, no hemos terminado. Este árbol ha predicciones a partir de los datos de entrenamiento que hemos proporcionado. ¿Recuerdas que el algoritmo busca la mejor separación para crear grupos? Si nuestros datos cambian, la variable que hace la mejore separación también puede cambiar. Y por lo tanto, los grupos que resulten de esta separación, serán distintos, resultando en un modelo que puede ser muy distinto al que hemos obtenido.

Generamos un segundo árbol, usando sets de entrenamiento y prueba diferentes.

set.seed(7439)
vino_entrenamiento_2 <- sample_frac(vino, .7)

vino_prueba_2 <- setdiff(vino, vino_entrenamiento)

arbol_2 <- rpart(formula = tipo ~ ., data = vino_entrenamiento_2)

prediccion_2 <- predict(arbol_2, newdata = vino_prueba_2, type = "class")

Veamos los resultados.

rpart.plot(arbol_2)

Matriz de Confusión

confusionMatrix(prediccion_2, vino_prueba_2[["tipo"]])
Confusion Matrix and Statistics

          Reference
Prediction  1  2  3
         1 14  0  0
         2  1 21  0
         3  0  0 17

Overall Statistics
                                          
               Accuracy : 0.9811          
                 95% CI : (0.8993, 0.9995)
    No Information Rate : 0.3962          
    P-Value [Acc > NIR] : < 2.2e-16       
                                          
                  Kappa : 0.9713          
                                          
 Mcnemar's Test P-Value : NA              

Statistics by Class:

                     Class: 1 Class: 2 Class: 3
Sensitivity            0.9333   1.0000   1.0000
Specificity            1.0000   0.9688   1.0000
Pos Pred Value         1.0000   0.9545   1.0000
Neg Pred Value         0.9744   1.0000   1.0000
Prevalence             0.2830   0.3962   0.3208
Detection Rate         0.2642   0.3962   0.3208
Detection Prevalence   0.2642   0.4151   0.3208
Balanced Accuracy      0.9667   0.9844   1.0000