Skip to content

Commit 3d0b4de

Browse files
committed
fix: Edge cases and lint errors from validity bit refactor.
1 parent 9b7daf5 commit 3d0b4de

22 files changed

Lines changed: 601 additions & 374 deletions

File tree

app/Synthesis.hs

Lines changed: 94 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -1,53 +1,51 @@
1+
{-# LANGUAGE DataKinds #-}
12
{-# LANGUAGE NumericUnderscores #-}
23
{-# LANGUAGE OverloadedStrings #-}
34
{-# LANGUAGE TemplateHaskell #-}
45
{-# LANGUAGE TypeApplications #-}
6+
{-# LANGUAGE TypeFamilies #-}
7+
{-# LANGUAGE TypeOperators #-}
58

9+
import Data.Char
610
import qualified Data.Text as T
711
import qualified DataFrame as D
12+
import DataFrame.DecisionTree
813
import qualified DataFrame.Functions as F
14+
import DataFrame.Operators
15+
import qualified DataFrame.Typed as DT
16+
import System.Random
917

10-
import Data.Char
11-
import DataFrame.DecisionTree
12-
import DataFrame.Operators hiding (name)
18+
-- Template Haskell splices: each call hands read options, a schema name and a
-- CSV path to 'DT.deriveSchemaFromCsvFileWith'. The names given here,
-- "TrainSchema" and "TestSchema", are the ones 'main' passes to 'DT.freeze'
-- as type applications. Both splices enable 'D.safeRead'.
-- NOTE(review): presumably the CSV files are read at compile time, relative
-- to the build's working directory — confirm against the library docs.
$( DT.deriveSchemaFromCsvFileWith
19+
D.defaultReadOptions{D.safeRead = True}
20+
"TrainSchema"
21+
"./data/titanic/train.csv"
22+
)
23+
$( DT.deriveSchemaFromCsvFileWith
24+
D.defaultReadOptions{D.safeRead = True}
25+
"TestSchema"
26+
"./data/titanic/test.csv"
27+
)
1328

14-
import System.Random
29+
{- | Two-column schema used when scoring predictions.

"Survived" is @Maybe Int@ (the data is read with @safeRead = True@);
"prediction" is a plain @Int@ because it is the model output.
-}
type RawPredSchema =
    '[ DT.Column "Survived" (Maybe Int)
     , DT.Column "prediction" Int
     ]
1532

16-
$(F.declareColumnsFromCsvFile "./data/titanic/train.csv")
33+
-- | Expression referring to the @Int@-typed "prediction" column.
prediction :: D.Expr Int
prediction = F.col "prediction" -- column type is fixed by the signature
1735

1836
main :: IO ()
1937
main = do
20-
train <- D.readCsv "./data/titanic/train.csv"
21-
test <- D.readCsv "./data/titanic/test.csv"
38+
rawTrain <- D.readCsv "./data/titanic/train.csv"
39+
rawTest <- D.readCsv "./data/titanic/test.csv"
2240

23-
-- Apply the same transformations to training and test.
24-
let combined =
25-
(train <> test)
26-
|> D.deriveMany
27-
[ "Ticket" .= F.lift (T.filter isAlpha) ticket
28-
, "Name" .= F.match "\\s*([A-Za-z]+)\\." name
29-
, "Cabin" .= F.whenPresent (T.take 1) cabin
30-
]
31-
|> D.renameMany
32-
[ (F.name name, "title")
33-
, (F.name cabin, "cabin_prefix")
34-
, (F.name pclass, "passenger_class")
35-
, (F.name sibsp, "number_of_siblings_and_spouses")
36-
, (F.name parch, "number_of_parents_and_children")
37-
]
38-
print combined
41+
train <-
42+
maybe (fail "train.csv schema mismatch") pure (DT.freeze @TrainSchema rawTrain)
43+
test <-
44+
maybe (fail "test.csv schema mismatch") pure (DT.freeze @TestSchema rawTest)
3945

40-
let (train', validation) =
41-
D.take
42-
(D.nRows train)
43-
combined
44-
|> D.filterJust (F.name survived)
45-
|> D.randomSplit (mkStdGen 4232) 0.7
46-
-- Split the test out again.
47-
test' =
48-
D.drop
49-
(D.nRows train)
50-
combined
46+
let (trainDf, validDf) =
47+
D.randomSplit (mkStdGen 4232) 0.7 (DT.thaw (clean train))
48+
testDf = DT.thaw (clean test)
5149

5250
model =
5351
fitDecisionTree
@@ -61,70 +59,84 @@ main = do
6159
{ complexityPenalty = 0.1
6260
, maxExprDepth = 3
6361
, disallowedCombinations =
64-
[ (F.name age, F.name fare)
62+
[ ("Age", "Fare")
6563
, ("passenger_class", "number_of_siblings_and_spouses")
6664
, ("passenger_class", "number_of_parents_and_children")
6765
]
6866
}
6967
}
7068
)
71-
survived -- Label to predict
72-
( train'
73-
|> D.exclude [F.name passengerid]
74-
)
69+
(F.fromMaybe 0 (F.col @(Maybe Int) "Survived"))
70+
(trainDf |> D.exclude ["PassengerId"])
7571

7672
print model
7773

7874
putStrLn "Training accuracy: "
79-
print $
80-
computeAccuracy
81-
(train' |> D.derive (F.name prediction) model)
75+
print $ computeAccuracy (trainDf |> D.derive (F.name prediction) model)
8276

8377
putStrLn "Validation accuracy: "
84-
print $
85-
computeAccuracy
86-
( validation
87-
|> D.derive (F.name prediction) model
88-
)
78+
print $ computeAccuracy (validDf |> D.derive (F.name prediction) model)
8979

90-
let predictions = D.derive (F.name survived) model test'
9180
D.writeCsv
9281
"./predictions.csv"
93-
(predictions |> D.select [F.name passengerid, F.name survived])
82+
( testDf
83+
|> D.derive "Survived" model
84+
|> D.select ["PassengerId", "Survived"]
85+
)
9486

95-
prediction :: D.Expr Int
96-
prediction = F.col @Int "prediction"
87+
{- | Apply the shared Titanic feature clean-up to a typed frame.

Three text columns are rewritten in place via 'DT.nullLift', then five
columns are renamed:

* "Ticket" keeps only its alphabetic characters.
* "Name" is reduced to the passenger's title by 'extractTitle'.
* "Cabin" keeps only its first character.

The constraints require "Ticket", "Name" and "Cabin" to be present in @cols@
with type @Maybe T.Text@.
-}
clean ::
    ( DT.AssertPresent "Ticket" cols
    , DT.Lookup "Ticket" cols ~ Maybe T.Text
    , DT.AssertPresent "Name" cols
    , DT.Lookup "Name" cols ~ Maybe T.Text
    , DT.AssertPresent "Cabin" cols
    , DT.Lookup "Cabin" cols ~ Maybe T.Text
    ) =>
    DT.TypedDataFrame cols ->
    DT.TypedDataFrame
        ( DT.RenameManyInSchema
            '[ '("Name", "title")
             , '("Cabin", "cabin_prefix")
             , '("Pclass", "passenger_class")
             , '("SibSp", "number_of_siblings_and_spouses")
             , '("Parch", "number_of_parents_and_children")
             ]
            cols
        )
clean tdf =
    DT.renameMany
        @'[ '("Name", "title")
          , '("Cabin", "cabin_prefix")
          , '("Pclass", "passenger_class")
          , '("SibSp", "number_of_siblings_and_spouses")
          , '("Parch", "number_of_parents_and_children")
          ]
        cabinsShortened
  where
    -- Step 1: strip everything but letters from the ticket text.
    ticketsAlpha =
        DT.replaceColumn
            @"Ticket"
            (DT.nullLift (T.filter isAlpha) (DT.col @"Ticket"))
            tdf
    -- Step 2: replace the full name with just the title.
    namesAsTitles =
        DT.replaceColumn
            @"Name"
            (DT.nullLift extractTitle (DT.col @"Name"))
            ticketsAlpha
    -- Step 3: keep only the leading character of the cabin text.
    cabinsShortened =
        DT.replaceColumn
            @"Cabin"
            (DT.nullLift (T.take 1) (DT.col @"Cabin"))
            namesAsTitles
118+
119+
{- | Pull the title (e.g. "Mr", "Mrs") out of a full Titanic passenger name.

The name is split on whitespace; the first word that ends in a period is
returned with that trailing period removed. If no word ends in a period the
result is the empty text.
-}
extractTitle :: T.Text -> T.Text
extractTitle fullName =
    case dropWhile (not . T.isSuffixOf ".") (T.words fullName) of
        titled : _ -> T.dropEnd 1 titled
        [] -> ""
97125

126+
{- | Compute binary classification accuracy from a DataFrame containing
"Survived" and "prediction" columns.

Only those two columns are selected, and missing "Survived" values are
imputed as 0 before counting. Accuracy is @(tp + tn) / (tp + tn + fp + fn)@,
where each term counts the rows matching one label/prediction combination
over the values 0 and 1. When no row falls into any of the four cells (for
example an empty frame) the result is 0 instead of NaN.
-}
computeAccuracy :: D.DataFrame -> Double
computeAccuracy df
    | total == 0 = 0 -- avoid 0/0 when there is nothing to score
    | otherwise = (tp + tn) / total
  where
    tdf =
        DT.impute @"Survived" 0 $
            DT.unsafeFreeze @RawPredSchema $
                df |> D.select ["Survived", "prediction"]
    survived = DT.col @"Survived"
    -- Named 'predicted' (not 'pred') so Prelude's 'pred' is not shadowed.
    predicted = DT.col @"prediction"
    count expr = fromIntegral (DT.nRows (DT.filterWhere expr tdf))
    tp = count ((survived DT..==. DT.lit 1) DT..&&. (predicted DT..==. DT.lit 1))
    tn = count ((survived DT..==. DT.lit 0) DT..&&. (predicted DT..==. DT.lit 0))
    fp = count ((survived DT..==. DT.lit 0) DT..&&. (predicted DT..==. DT.lit 1))
    fn = count ((survived DT..==. DT.lit 1) DT..&&. (predicted DT..==. DT.lit 0))
    total = tp + tn + fp + fn

0 commit comments

Comments
 (0)