All State Loss Prediction - Multiple Linear Regression on select fields

Required Libraries
- rmarkdown is used to generate a record of the work
- caret is the core library used to enable training and generation of predictions
library(rmarkdown)
library(caret)
## Loading required package: lattice
## Loading required package: ggplot2
Input Files
- train.csv: file containing 132 columns, including loss
- test.csv: file containing 131 columns; loss needs to be predicted
- submit.csv: sample submission file containing the id from test.csv and the predicted loss
Our objective is to:

- use the train data to build the model
- apply the model over the test data
- predict the loss for each case

DATA
Reading the train.csv file. Its columns are:

- id
- categorical variables cat1 through cat116
- continuous variables cont1 through cont14
- loss
#allstate=read.csv("./train.csv", nrows=10000)  # quick subset for faster development runs
allstate=read.csv("./train.csv")
head(allstate)
## id cat1 cat2 cat3 cat4 cat5 cat6 cat7 cat8 cat9 cat10 cat11 cat12 cat13
## 1 1 A B A B A A A A B A B A A
## 2 2 A B A A A A A A B B A A A
## 3 5 A B A A B A A A B B B B B
## 4 10 B B A B A A A A B A A A A
## 5 11 A B A B A A A A B B A B A
## 6 13 A B A A A A A A B A A A A
## cat14 cat15 cat16 cat17 cat18 cat19 cat20 cat21 cat22 cat23 cat24 cat25
## 1 A A A A A A A A A B A A
## 2 A A A A A A A A A A A A
## 3 A A A A A A A A A A A A
## 4 A A A A A A A A A B A A
## 5 A A A A A A A A A B A A
## 6 A A A A A A A A A A A A
## cat26 cat27 cat28 cat29 cat30 cat31 cat32 cat33 cat34 cat35 cat36 cat37
## 1 A A A A A A A A A A A A
## 2 A A A A A A A A A A A A
## 3 A A A A A A A A A A B A
## 4 A A A A A A A A A A A A
## 5 A A A A A A A A A A A A
## 6 A A A A A A A A A A A A
## cat38 cat39 cat40 cat41 cat42 cat43 cat44 cat45 cat46 cat47 cat48 cat49
## 1 A A A A A A A A A A A A
## 2 A A A A A A A A A A A A
## 3 A A A A A A A A A A A A
## 4 A A A A A A A A A A A A
## 5 A A A A A A A A A A A A
## 6 A A A A A A A A A A A A
## cat50 cat51 cat52 cat53 cat54 cat55 cat56 cat57 cat58 cat59 cat60 cat61
## 1 A A A A A A A A A A A A
## 2 A A A A A A A A A A A A
## 3 A A A A A A A A A A A A
## 4 A A A A A A A A A A A A
## 5 A A A A A A A A A A A A
## 6 A A A A A A A A A A A A
## cat62 cat63 cat64 cat65 cat66 cat67 cat68 cat69 cat70 cat71 cat72 cat73
## 1 A A A A A A A A A A A A
## 2 A A A A A A A A A A A A
## 3 A A A A A A A A A A A A
## 4 A A A A A A A A A A A B
## 5 A A A A A A A A A A B A
## 6 A A A A A A A A A A B A
## cat74 cat75 cat76 cat77 cat78 cat79 cat80 cat81 cat82 cat83 cat84 cat85
## 1 A B A D B B D D B D C B
## 2 A A A D B B D D A B C B
## 3 A A A D B B B D B D C B
## 4 A A A D B B D D D B C B
## 5 A A A D B D B D B B C B
## 6 A A A D B D B D B B C B
## cat86 cat87 cat88 cat89 cat90 cat91 cat92 cat93 cat94 cat95 cat96 cat97
## 1 D B A A A A A D B C E A
## 2 D B A A A A A D D C E E
## 3 B B A A A A A D D C E E
## 4 D B A A A A A D D C E E
## 5 B C A A A B H D B D E E
## 6 B B A A A A A D D D E C
## cat98 cat99 cat100 cat101 cat102 cat103 cat104 cat105 cat106 cat107
## 1 C T B G A A I E G J
## 2 D T L F A A E E I K
## 3 A D L O A B E F H F
## 4 D T I D A A E E I K
## 5 A P F J A A D E K G
## 6 A P J D A A E E H F
## cat108 cat109 cat110 cat111 cat112 cat113 cat114 cat115 cat116 cont1
## 1 G BU BC C AS S A O LB 0.726300
## 2 K BI CQ A AV BM A O DP 0.330514
## 3 A AB DK A C AF A I GK 0.261841
## 4 K BI CS C N AE A O DJ 0.321594
## 5 B H C C Y BM A K CK 0.273204
## 6 B BI CS A AS AE A K DJ 0.546670
## cont2 cont3 cont4 cont5 cont6 cont7 cont8 cont9
## 1 0.245921 0.187583 0.789639 0.310061 0.718367 0.335060 0.30260 0.67135
## 2 0.737068 0.592681 0.614134 0.885834 0.438917 0.436585 0.60087 0.35127
## 3 0.358319 0.484196 0.236924 0.397069 0.289648 0.315545 0.27320 0.26076
## 4 0.555782 0.527991 0.373816 0.422268 0.440945 0.391128 0.31796 0.32128
## 5 0.159990 0.527991 0.473202 0.704268 0.178193 0.247408 0.24564 0.22089
## 6 0.681761 0.634224 0.373816 0.302678 0.364464 0.401162 0.26847 0.46226
## cont10 cont11 cont12 cont13 cont14 loss
## 1 0.83510 0.569745 0.594646 0.822493 0.714843 2213.18
## 2 0.43919 0.338312 0.366307 0.611431 0.304496 1283.60
## 3 0.32446 0.381398 0.373424 0.195709 0.774425 3005.09
## 4 0.44467 0.327915 0.321570 0.605077 0.602642 939.85
## 5 0.21230 0.204687 0.202213 0.246011 0.432606 2763.85
## 6 0.50556 0.366788 0.359249 0.345247 0.726792 5142.87
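Note on reproducibility: the cat fields must be factors for summary() and caret to behave as shown above. On R >= 4.0, read.csv no longer converts strings to factors by default, so the flag would need to be set explicitly; a minimal sketch:

# On R >= 4.0, request factors explicitly so the cat* columns
# are not left as plain character vectors.
allstate <- read.csv("./train.csv", stringsAsFactors = TRUE)
str(allstate$cat1)   # should report: Factor w/ 2 levels "A","B"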
Generating summary statistics on all fields
summary(allstate)
## id cat1 cat2 cat3 cat4 cat5
## Min. : 1 A:141550 A:106721 A:177993 A:128395 A:123737
## 1st Qu.:147748 B: 46768 B: 81597 B: 10325 B: 59923 B: 64581
## Median :294540
## Mean :294136
## 3rd Qu.:440680
## Max. :587633
##
## cat6 cat7 cat8 cat9 cat10 cat11
## A:131693 A:183744 A:177274 A:113122 A:160213 A:168186
## B: 56625 B: 4574 B: 11044 B: 75196 B: 28105 B: 20132
##
##
##
##
##
## cat12 cat13 cat14 cat15 cat16 cat17
## A:159825 A:168851 A:186041 A:188284 A:181843 A:187009
## B: 28493 B: 19467 B: 2277 B: 34 B: 6475 B: 1309
##
##
##
##
##
## cat18 cat19 cat20 cat21 cat22 cat23
## A:187331 A:186510 A:188114 A:187905 A:188275 A:157445
## B: 987 B: 1808 B: 204 B: 413 B: 43 B: 30873
##
##
##
##
##
## cat24 cat25 cat26 cat27 cat28 cat29
## A:181977 A:169969 A:177119 A:168250 A:180938 A:184593
## B: 6341 B: 18349 B: 11199 B: 20068 B: 7380 B: 3725
##
##
##
##
##
## cat30 cat31 cat32 cat33 cat34 cat35
## A:184760 A:182980 A:187107 A:187361 A:187734 A:188105
## B: 3558 B: 5338 B: 1211 B: 957 B: 584 B: 213
##
##
##
##
##
## cat36 cat37 cat38 cat39 cat40 cat41
## A:156313 A:165729 A:169323 A:183393 A:180119 A:181177
## B: 32005 B: 22589 B: 18995 B: 4925 B: 8199 B: 7141
##
##
##
##
##
## cat42 cat43 cat44 cat45 cat46 cat47
## A:186623 A:184110 A:172716 A:183991 A:187436 A:187617
## B: 1695 B: 4208 B: 15602 B: 4327 B: 882 B: 701
##
##
##
##
##
## cat48 cat49 cat50 cat51 cat52 cat53
## A:188049 A:179127 A:137611 A:187071 A:179505 A:172949
## B: 269 B: 9191 B: 50707 B: 1247 B: 8813 B: 15369
##
##
##
##
##
## cat54 cat55 cat56 cat57 cat58 cat59
## A:183762 A:188173 A:188136 A:185296 A:188079 A:188018
## B: 4556 B: 145 B: 182 B: 3022 B: 239 B: 300
##
##
##
##
##
## cat60 cat61 cat62 cat63 cat64 cat65
## A:187872 A:187596 A:188273 A:188239 A:188271 A:186056
## B: 446 B: 722 B: 45 B: 79 B: 47 B: 2262
##
##
##
##
##
## cat66 cat67 cat68 cat69 cat70 cat71
## A:179982 A:187626 A:188176 A:188011 A:188295 A:178646
## B: 8336 B: 692 B: 142 B: 307 B: 23 B: 9672
##
##
##
##
##
## cat72 cat73 cat74 cat75 cat76 cat77
## A:118322 A:154275 A:184731 A:154307 A:181347 A: 49
## B: 69996 B: 34017 B: 3561 B: 34010 B: 6183 B: 358
## C: 26 C: 26 C: 1 C: 788 C: 408
## D:187503
##
##
##
## cat78 cat79 cat80 cat81 cat82 cat83
## A: 788 A: 7064 A: 783 A: 788 A: 19322 A: 26038
## B:186526 B:152929 B: 46538 B: 24132 B:147536 B:141534
## C: 645 C: 1668 C: 3492 C: 9013 C: 2655 C: 4958
## D: 359 D: 26657 D:137505 D:154385 D: 18805 D: 15788
##
##
##
## cat84 cat85 cat86 cat87 cat88 cat89
## A: 29450 A: 788 A: 1589 A: 788 A:168926 A :183744
## B: 431 B:186005 B:103852 B:166992 B: 7 B : 4312
## C:154939 C: 1011 C: 10290 C: 8819 D: 19302 C : 220
## D: 3498 D: 514 D: 72587 D: 11719 E: 83 D : 33
## E : 5
## I : 2
## (Other): 2
## cat90 cat91 cat92 cat93 cat94 cat95
## A:177993 A :111028 A:124689 A: 432 A: 738 A: 3736
## B: 9515 B : 42630 B: 628 B: 1133 B: 51710 B: 109
## C: 728 G : 26734 C: 62 C: 35788 C: 13623 C:87531
## D: 70 C : 6400 D: 11 D:150237 D:121642 D:79525
## E: 6 D : 1149 F: 1 E: 728 E: 91 E:17417
## F: 4 E : 254 H: 62901 F: 494
## G: 2 (Other): 123 I: 26 G: 20
## cat96 cat97 cat98 cat99 cat100
## E :174360 A:41970 A:105492 P :79455 F :42970
## D : 7922 B: 34 B: 542 T :72591 I :39933
## B : 2957 C:78127 C: 21485 R :10290 L :19961
## G : 2665 D: 3779 D: 50557 D : 8844 K :13817
## F : 343 E:47450 E: 10242 S : 7045 G :12935
## A : 35 F: 213 N : 2894 J :12027
## (Other): 36 G:16745 (Other): 7199 (Other):46675
## cat101 cat102 cat103 cat104
## A :106721 A :177274 A :123737 E :42925
## D : 17171 B : 5155 B : 33342 G :40660
## C : 16971 C : 4929 C : 16508 D :27611
## G : 10944 E : 482 D : 7806 F :19228
## F : 10139 D : 449 E : 4473 H :17187
## J : 7259 G : 15 F : 1528 K :14297
## (Other): 19113 (Other): 14 (Other): 924 (Other):26410
## cat105 cat106 cat107 cat108
## E :76493 G :47165 F :47310 B :65512
## F :62892 H :37713 G :28560 K :42435
## G :20613 F :36143 H :23461 G :21421
## D :12172 I :21433 J :22405 D :19160
## H :11258 J :18281 K :20236 F :10242
## I : 2941 E :13000 I :20066 A : 9299
## (Other): 1949 (Other):14583 (Other):26280 (Other):20249
## cat109 cat110 cat111 cat112
## BI :152918 CL :25305 A :128395 E :25148
## AB : 21933 EG :24654 C : 32401 AH :18639
## BU : 3142 CS :24592 E : 14682 AS :17669
## K : 2999 EB :21396 G : 7039 J :16222
## G : 1353 CO :17495 I : 3578 AF : 9368
## BQ : 1067 BT :16365 K : 1353 AN : 9138
## (Other): 4906 (Other):58511 (Other): 870 (Other):92134
## cat113 cat114 cat115 cat116
## BM :26191 A :131693 K :43866 HK : 21061
## AE :22030 C : 16793 O :26813 DJ : 20244
## L :13058 E : 16475 J :23895 CK : 10162
## AX :12661 J : 8199 N :22438 DP : 9202
## Y :11374 F : 7905 P :21538 GS : 8736
## K : 7738 N : 2455 L :16125 CR : 6862
## (Other):95266 (Other): 4798 (Other):33643 (Other):112051
## cont1 cont2 cont3 cont4
## Min. :0.000016 Min. :0.001149 Min. :0.002634 Min. :0.1769
## 1st Qu.:0.346090 1st Qu.:0.358319 1st Qu.:0.336963 1st Qu.:0.3274
## Median :0.475784 Median :0.555782 Median :0.527991 Median :0.4529
## Mean :0.493861 Mean :0.507188 Mean :0.498918 Mean :0.4918
## 3rd Qu.:0.623912 3rd Qu.:0.681761 3rd Qu.:0.634224 3rd Qu.:0.6521
## Max. :0.984975 Max. :0.862654 Max. :0.944251 Max. :0.9543
##
## cont5 cont6 cont7 cont8
## Min. :0.2811 Min. :0.01268 Min. :0.0695 Min. :0.2369
## 1st Qu.:0.2811 1st Qu.:0.33610 1st Qu.:0.3502 1st Qu.:0.3128
## Median :0.4223 Median :0.44094 Median :0.4383 Median :0.4411
## Mean :0.4874 Mean :0.49094 Mean :0.4850 Mean :0.4864
## 3rd Qu.:0.6433 3rd Qu.:0.65502 3rd Qu.:0.5910 3rd Qu.:0.6236
## Max. :0.9837 Max. :0.99716 Max. :1.0000 Max. :0.9802
##
## cont9 cont10 cont11 cont12
## Min. :0.00008 Min. :0.0000 Min. :0.03532 Min. :0.03623
## 1st Qu.:0.35897 1st Qu.:0.3646 1st Qu.:0.31096 1st Qu.:0.31166
## Median :0.44145 Median :0.4612 Median :0.45720 Median :0.46229
## Mean :0.48551 Mean :0.4981 Mean :0.49351 Mean :0.49315
## 3rd Qu.:0.56682 3rd Qu.:0.6146 3rd Qu.:0.67892 3rd Qu.:0.67576
## Max. :0.99540 Max. :0.9950 Max. :0.99874 Max. :0.99848
##
## cont13 cont14 loss
## Min. :0.000228 Min. :0.1797 Min. : 0.67
## 1st Qu.:0.315758 1st Qu.:0.2946 1st Qu.: 1204.46
## Median :0.363547 Median :0.4074 Median : 2115.57
## Mean :0.493138 Mean :0.4957 Mean : 3037.34
## 3rd Qu.:0.689974 3rd Qu.:0.7246 3rd Qu.: 3864.05
## Max. :0.988494 Max. :0.8448 Max. :121012.25
##
Generating loss histograms for the categorical fields (illustrated on the first eight)
# Alternative exploratory plots (fragment, not run):
# histogram(loss, data=allstate[, c(132, 2:118)])
# Multiple histograms over selected columns:
# par(mfrow=c(3, 3))
# for (i in ...) {   # loop header lost; columns to plot unspecified
#   hist(allstate[, i], xlim=c(0, 3500), breaks=seq(0, 3500, 100),
#        main=names(allstate)[i], probability=TRUE, col="gray", border="white")
# }
hist(allstate$loss)
We see that the loss itself is heavily skewed. We need to normalize it to get a better prediction, so we will try taking the log of the loss to see if that helps.
allstate$logloss <- log(allstate$loss)
hist(allstate$logloss)
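To quantify the improvement, the skewness of loss and logloss can be compared directly; a minimal sketch, assuming the e1071 package (not loaded above) is available for its skewness() function:

library(e1071)               # assumed available; provides skewness()
skewness(allstate$loss)      # strongly positive for the raw loss
skewness(allstate$logloss)   # should be much closer to 0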
catvars <- paste("cat", 1:116, sep="")
for (i in 1:8) {
  catvar <- catvars[i]
  # Inside a loop, ggplot objects must be printed explicitly to render.
  print(ggplot(allstate, aes_string("loss", fill=catvar)) + geom_histogram())
}
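An alternative to the loop is a single faceted figure; a sketch using only base R and the already-loaded ggplot2, with the long-format table built via lapply to avoid extra dependencies:

# Stack loss against each of the first eight categorical fields.
long <- do.call(rbind, lapply(paste("cat", 1:8, sep=""), function(v) {
  data.frame(loss = allstate$loss, variable = v,
             level = as.character(allstate[[v]]))
}))
ggplot(long, aes(loss, fill = level)) +
  geom_histogram(bins = 30) +
  facet_wrap(~ variable, scales = "free_y")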
PREDICTION
Algorithm - regularized linear regression (caret method xgbLinear) on select fields with 10-fold cross-validation
Partitioning the training data
We are going to take 80% of the train data and use it for our model training. The remaining 20% of the data will be used to determine how well our model worked.
#featurePlot(x=allstate[,3], y=allstate[,1], plot='density')
#allstate=allstateFile
set.seed(1234)
# define an 80%/20% train/test split of the dataset
split=0.80
trainIndex <- createDataPartition(allstate$id, p=split, list=FALSE)
data_train <- allstate[ trainIndex,]
data_test <- allstate[-trainIndex,]
#str(data_train)
10-fold Cross-Validated Model
We will train the model with a cross-validated 10-fold process. The method used is xgbLinear (a regularized linear model fit via xgboost), and the metric used to select the model is the lowest RMSE.
catfactors <- paste("cat", 1:116, sep="")
contfactors <-paste("cont", 1:14, sep="")
formula = reformulate(termlabels = c(catfactors,contfactors), response = 'logloss')
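A quick sanity check on the assembled formula (reformulate() simply joins the two name vectors into the right-hand side):

length(all.vars(formula))   # 131 = logloss + 116 cat + 14 cont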
#modelFit <- train( formula,data=allstate, method="rpart" )
#varImp(modelFit)
## rpart variable importance
##
##   only 20 most important variables shown (out of 1037)
##
##        Overall
## cat80D  100.00
## cat80B   99.75
## cat12B   78.30
## cat79D   75.14
## cat79B   59.11
## cat10B   18.35
## cat1B    18.27
## cat81D   16.09
## cat81B   14.15
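The scores above are per dummy-coded level (cat80D is level D of cat80). Grouping them back to their parent variables motivates the six fields kept below; a sketch, assuming the commented-out rpart fit above has been run so that modelFit exists:

imp <- varImp(modelFit)$importance
imp$var <- sub("[A-Z]+$", "", rownames(imp))     # strip the level suffix
aggregate(Overall ~ var, data = imp, FUN = max)  # best score per variable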
#Taking only important variables
catfactors <- c("cat80","cat12","cat79", "cat10", "cat1", "cat81")
formula = reformulate(termlabels = c(catfactors,contfactors), response = 'logloss')
ControlParameters <- trainControl(method = "cv",          # 10-fold cross-validation
                                  number = 10,
                                  savePredictions = TRUE,
                                  classProbs = TRUE,      # ignored for regression; triggers the caret warning below
                                  verboseIter = TRUE
)
parametersGrid <- expand.grid(nrounds=100,  # boosting iterations
                              lambda=.5,    # L2 regularization weight
                              alpha=.5,     # L1 regularization weight
                              eta = 0.1     # learning rate
)
model.xgboost <- train(formula, data = data_train, method = "xgbLinear", trControl = ControlParameters, tuneGrid = parametersGrid)
## Loading required package: xgboost
## Warning in train.default(x, y, weights = w, ...): cannnot compute class
## probabilities for regression
## + Fold01: nrounds=100, lambda=0.5, alpha=0.5, eta=0.1
## - Fold01: nrounds=100, lambda=0.5, alpha=0.5, eta=0.1
## + Fold02: nrounds=100, lambda=0.5, alpha=0.5, eta=0.1
## - Fold02: nrounds=100, lambda=0.5, alpha=0.5, eta=0.1
## + Fold03: nrounds=100, lambda=0.5, alpha=0.5, eta=0.1
## - Fold03: nrounds=100, lambda=0.5, alpha=0.5, eta=0.1
## + Fold04: nrounds=100, lambda=0.5, alpha=0.5, eta=0.1
## - Fold04: nrounds=100, lambda=0.5, alpha=0.5, eta=0.1
## + Fold05: nrounds=100, lambda=0.5, alpha=0.5, eta=0.1
## - Fold05: nrounds=100, lambda=0.5, alpha=0.5, eta=0.1
## + Fold06: nrounds=100, lambda=0.5, alpha=0.5, eta=0.1
## - Fold06: nrounds=100, lambda=0.5, alpha=0.5, eta=0.1
## + Fold07: nrounds=100, lambda=0.5, alpha=0.5, eta=0.1
## - Fold07: nrounds=100, lambda=0.5, alpha=0.5, eta=0.1
## + Fold08: nrounds=100, lambda=0.5, alpha=0.5, eta=0.1
## - Fold08: nrounds=100, lambda=0.5, alpha=0.5, eta=0.1
## + Fold09: nrounds=100, lambda=0.5, alpha=0.5, eta=0.1
## - Fold09: nrounds=100, lambda=0.5, alpha=0.5, eta=0.1
## + Fold10: nrounds=100, lambda=0.5, alpha=0.5, eta=0.1
## - Fold10: nrounds=100, lambda=0.5, alpha=0.5, eta=0.1
## Aggregating results
## Fitting final model on full training set
warnings()
## NULL
summary(model.xgboost)
## Length Class Mode
## handle 1 xgb.Booster.handle externalptr
## raw 422143 -none- raw
## niter 1 -none- numeric
## call 5 -none- call
## params 4 -none- list
## callbacks 1 -none- list
## xNames 26 -none- character
## problemType 1 -none- character
## tuneValue 4 data.frame list
## obsLevels 1 -none- logical
## param 0 -none- list
Validating our model
Validation of the trained model on the held-out test data (20% of the train data)
x_test <- data_test[c(catfactors,contfactors)]
y_test <- data_test[,"logloss"]
#plot(x_test,y_test)
Compute predictions from the trained model on the 20% validation split
predictions <- predict(model.xgboost, x_test)
str(predictions)
## num [1:37662] 7.1 8.06 7.31 8.73 8.93 ...
head(y_test)
## [1] 6.845720 7.924380 7.031936 8.184723 9.237975 8.693787
hist(predictions)
Computing RMSE and R2
caret::RMSE(pred = predictions, obs = y_test)
## [1] 0.6188222
caret::R2(pred = predictions, obs = y_test)
## [1] 0.4221604
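As a sanity check, both metrics can be computed by hand: RMSE is sqrt(mean((pred - obs)^2)), and caret's default R2 is the squared Pearson correlation between predictions and observations:

sqrt(mean((predictions - y_test)^2))   # should match the RMSE above
cor(predictions, y_test)^2             # should match the R2 above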
#defaultsummary(data_train)
Predicting Loss
Applying the trained model to the data set that needs prediction.
testFile=read.csv("/users/smithamathew/kaggle/test.csv")
out_test <-testFile
id <-out_test$id
str(id)
## int [1:125546] 4 6 9 12 15 17 21 28 32 43 ...
logloss <- predict(model.xgboost, out_test)
loss <- exp(logloss)
head(loss)
## [1] 1328.056 2035.727 6157.791 4315.195 1413.001 1512.084
hist(loss)
out_file=cbind(id,loss)
head(out_file)
## id loss
## [1,] 4 1328.056
## [2,] 6 2035.727
## [3,] 9 6157.791
## [4,] 12 4315.195
## [5,] 15 1413.001
## [6,] 17 1512.084
options(scipen=999)
write.csv(file="/users/smithamathew/kaggle/submit.csv",out_file,row.names = FALSE)
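A quick check of the written file can catch format problems before submission; a minimal sketch (the path matches the write.csv call above):

# Re-read the submission and verify its shape and header.
submission <- read.csv("/users/smithamathew/kaggle/submit.csv")
stopifnot(ncol(submission) == 2,
          nrow(submission) == length(id),
          identical(names(submission), c("id", "loss")))
head(submission)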