1+ {-# LANGUAGE DataKinds #-}
12{-# LANGUAGE NumericUnderscores #-}
23{-# LANGUAGE OverloadedStrings #-}
34{-# LANGUAGE TemplateHaskell #-}
45{-# LANGUAGE TypeApplications #-}
6+ {-# LANGUAGE TypeFamilies #-}
7+ {-# LANGUAGE TypeOperators #-}
58
9+ import Data.Char
610import qualified Data.Text as T
711import qualified DataFrame as D
12+ import DataFrame.DecisionTree
813import qualified DataFrame.Functions as F
14+ import DataFrame.Operators
15+ import qualified DataFrame.Typed as DT
16+ import System.Random
917
10- import Data.Char
11- import DataFrame.DecisionTree
12- import DataFrame.Operators hiding (name )
18+ $ ( DT. deriveSchemaFromCsvFileWith
19+ D. defaultReadOptions{D. safeRead = True }
20+ " TrainSchema"
21+ " ./data/titanic/train.csv"
22+ )
23+ $ ( DT. deriveSchemaFromCsvFileWith
24+ D. defaultReadOptions{D. safeRead = True }
25+ " TestSchema"
26+ " ./data/titanic/test.csv"
27+ )
1328
14- import System.Random
29+ -- Survived is Maybe Int (safeRead = True); prediction is Int (model output).
30+ type RawPredSchema =
31+ '[DT. Column " Survived" (Maybe Int ), DT. Column " prediction" Int ]
1532
16- $ (F. declareColumnsFromCsvFile " ./data/titanic/train.csv" )
33+ prediction :: D. Expr Int
34+ prediction = F. col @ Int " prediction"
1735
1836main :: IO ()
1937main = do
20- train <- D. readCsv " ./data/titanic/train.csv"
21- test <- D. readCsv " ./data/titanic/test.csv"
38+ rawTrain <- D. readCsv " ./data/titanic/train.csv"
39+ rawTest <- D. readCsv " ./data/titanic/test.csv"
2240
23- -- Apply the same transformations to training and test.
24- let combined =
25- (train <> test)
26- |> D. deriveMany
27- [ " Ticket" .= F. lift (T. filter isAlpha) ticket
28- , " Name" .= F. match " \\ s*([A-Za-z]+)\\ ." name
29- , " Cabin" .= F. whenPresent (T. take 1 ) cabin
30- ]
31- |> D. renameMany
32- [ (F. name name, " title" )
33- , (F. name cabin, " cabin_prefix" )
34- , (F. name pclass, " passenger_class" )
35- , (F. name sibsp, " number_of_siblings_and_spouses" )
36- , (F. name parch, " number_of_parents_and_children" )
37- ]
38- print combined
41+ train <-
42+ maybe (fail " train.csv schema mismatch" ) pure (DT. freeze @ TrainSchema rawTrain)
43+ test <-
44+ maybe (fail " test.csv schema mismatch" ) pure (DT. freeze @ TestSchema rawTest)
3945
40- let (train', validation) =
41- D. take
42- (D. nRows train)
43- combined
44- |> D. filterJust (F. name survived)
45- |> D. randomSplit (mkStdGen 4232 ) 0.7
46- -- Split the test out again.
47- test' =
48- D. drop
49- (D. nRows train)
50- combined
46+ let (trainDf, validDf) =
47+ D. randomSplit (mkStdGen 4232 ) 0.7 (DT. thaw (clean train))
48+ testDf = DT. thaw (clean test)
5149
5250 model =
5351 fitDecisionTree
@@ -61,70 +59,84 @@ main = do
6159 { complexityPenalty = 0.1
6260 , maxExprDepth = 3
6361 , disallowedCombinations =
64- [ (F. name age, F. name fare )
62+ [ (" Age " , " Fare " )
6563 , (" passenger_class" , " number_of_siblings_and_spouses" )
6664 , (" passenger_class" , " number_of_parents_and_children" )
6765 ]
6866 }
6967 }
7068 )
71- survived -- Label to predict
72- ( train'
73- |> D. exclude [F. name passengerid]
74- )
69+ (F. fromMaybe 0 (F. col @ (Maybe Int ) " Survived" ))
70+ (trainDf |> D. exclude [" PassengerId" ])
7571
7672 print model
7773
7874 putStrLn " Training accuracy: "
79- print $
80- computeAccuracy
81- (train' |> D. derive (F. name prediction) model)
75+ print $ computeAccuracy (trainDf |> D. derive (F. name prediction) model)
8276
8377 putStrLn " Validation accuracy: "
84- print $
85- computeAccuracy
86- ( validation
87- |> D. derive (F. name prediction) model
88- )
78+ print $ computeAccuracy (validDf |> D. derive (F. name prediction) model)
8979
90- let predictions = D. derive (F. name survived) model test'
9180 D. writeCsv
9281 " ./predictions.csv"
93- (predictions |> D. select [F. name passengerid, F. name survived])
82+ ( testDf
83+ |> D. derive " Survived" model
84+ |> D. select [" PassengerId" , " Survived" ]
85+ )
9486
95- prediction :: D. Expr Int
96- prediction = F. col @ Int " prediction"
87+ clean ::
88+ ( DT. AssertPresent " Ticket" cols
89+ , DT. Lookup " Ticket" cols ~ Maybe T. Text
90+ , DT. AssertPresent " Name" cols
91+ , DT. Lookup " Name" cols ~ Maybe T. Text
92+ , DT. AssertPresent " Cabin" cols
93+ , DT. Lookup " Cabin" cols ~ Maybe T. Text
94+ ) =>
95+ DT. TypedDataFrame cols ->
96+ DT. TypedDataFrame
97+ ( DT. RenameManyInSchema
98+ '[ '(" Name" , " title" )
99+ , '(" Cabin" , " cabin_prefix" )
100+ , '(" Pclass" , " passenger_class" )
101+ , '(" SibSp" , " number_of_siblings_and_spouses" )
102+ , '(" Parch" , " number_of_parents_and_children" )
103+ ]
104+ cols
105+ )
106+ clean tdf =
107+ tdf
108+ |> DT. replaceColumn @ " Ticket" (DT. nullLift (T. filter isAlpha) (DT. col @ " Ticket" ))
109+ |> DT. replaceColumn @ " Name" (DT. nullLift extractTitle (DT. col @ " Name" ))
110+ |> DT. replaceColumn @ " Cabin" (DT. nullLift (T. take 1 ) (DT. col @ " Cabin" ))
111+ |> DT. renameMany
112+ @ '[ '(" Name" , " title" )
113+ , '(" Cabin" , " cabin_prefix" )
114+ , '(" Pclass" , " passenger_class" )
115+ , '(" SibSp" , " number_of_siblings_and_spouses" )
116+ , '(" Parch" , " number_of_parents_and_children" )
117+ ]
118+
119+ -- | Extract title (e.g. "Mr", "Mrs") from a full Titanic passenger name.
120+ extractTitle :: T. Text -> T. Text
121+ extractTitle name =
122+ case filter (T. isSuffixOf " ." ) (T. words name) of
123+ (w : _) -> T. dropEnd 1 w
124+ [] -> " "
97125
126+ {- | Compute binary classification accuracy from a DataFrame containing
127+ "Survived" and "prediction" columns.
128+ -}
98129computeAccuracy :: D. DataFrame -> Double
99130computeAccuracy df =
100- let
101- tp =
102- fromIntegral $
103- D. nRows
104- ( D. filterWhere
105- (survived .== F. lit (1 :: Int ) .&& prediction .== F. lit (1 :: Int ))
106- df
107- )
108- tn =
109- fromIntegral $
110- D. nRows
111- ( D. filterWhere
112- (survived .== F. lit (0 :: Int ) .&& prediction .== F. lit (0 :: Int ))
113- df
114- )
115- fp =
116- fromIntegral $
117- D. nRows
118- ( D. filterWhere
119- (survived .== F. lit (0 :: Int ) .&& prediction .== F. lit (1 :: Int ))
120- df
121- )
122- fn =
123- fromIntegral $
124- D. nRows
125- ( D. filterWhere
126- (survived .== F. lit (1 :: Int ) .&& prediction .== F. lit (0 :: Int ))
127- df
128- )
129- in
130- (tp + tn) / (tp + tn + fp + fn)
131+ let tdf =
132+ DT. impute @ " Survived" 0 $
133+ DT. unsafeFreeze @ RawPredSchema $
134+ df |> D. select [" Survived" , " prediction" ]
135+ survived = DT. col @ " Survived"
136+ pred = DT. col @ " prediction"
137+ count expr = fromIntegral (DT. nRows (DT. filterWhere expr tdf))
138+ tp = count ((survived DT. .==. DT. lit 1 ) DT. .&&. (pred DT. .==. DT. lit 1 ))
139+ tn = count ((survived DT. .==. DT. lit 0 ) DT. .&&. (pred DT. .==. DT. lit 0 ))
140+ fp = count ((survived DT. .==. DT. lit 0 ) DT. .&&. (pred DT. .==. DT. lit 1 ))
141+ fn = count ((survived DT. .==. DT. lit 1 ) DT. .&&. (pred DT. .==. DT. lit 0 ))
142+ in (tp + tn) / (tp + tn + fp + fn)
0 commit comments