-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathml_test_new.py
More file actions
146 lines (107 loc) · 3.75 KB
/
ml_test_new.py
File metadata and controls
146 lines (107 loc) · 3.75 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import pandas as pd
import os
from multiprocessing import Pool
from astropy.coordinates import SkyCoord, ICRS
from astropy import units as u
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.pyplot import cla
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn import metrics
from sklearn.preprocessing import OneHotEncoder
from sklearn import preprocessing
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix
def plotGenricSkyMap(coords):
    """
    Scatter-plot the given sky positions on a Mollweide projection.

    Args:
        coords: Object exposing ``ra`` and ``dec`` angle attributes that
            support ``wrap_at(...)`` and ``.radian`` (presumably an
            array-valued astropy ``SkyCoord``; the attributes are read on
            the object itself, not per element -- confirm against callers).

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    # Wrap both angles at 180 deg and convert to radians for the projection.
    # RA is negated so it runs right-to-left, matching the hour labels set
    # below (10h on the left through 0h in the centre to 14h on the right).
    ra_rad = -coords.ra.wrap_at(180 * u.deg).radian
    dec_rad = coords.dec.wrap_at(180 * u.deg).radian

    fig = plt.figure(figsize=(8, 6))
    ax = fig.add_subplot(111, projection="mollweide")

    # `vmin` is intentionally not passed: without colour data (`c`) matplotlib
    # ignores it and emits a warning about unused colormapping parameters.
    ax.scatter(ra_rad, dec_rad, s=3)
    ax.grid(True)
    ax.set_xticklabels(['10h', '8h', '6h', '4h', '2h', '0h', '22h', '20h', '18h', '16h', '14h'])
    plt.show()
def _load_feature_tables(mappings, feature_dir, label):
    """Read ``<feature_dir>/<SNID>.csv`` for every SNID listed in *mappings*.

    Args:
        mappings: DataFrame with an ``SNID`` column.
        feature_dir: Directory holding one ``<SNID>.csv`` file per object.
        label: Name used in the "not found" message (e.g. ``'Bulla'``).

    Returns:
        list of DataFrames, one per feature file that exists, in the order
        the SNIDs appear in *mappings*.
    """
    tables = []
    for snid in mappings['SNID']:
        try:
            table = pd.read_csv(f'{feature_dir}/{snid}.csv')
        except FileNotFoundError:
            # Only a missing file is tolerated and reported; any other error
            # (malformed CSV, Ctrl-C, ...) propagates instead of being
            # mislabelled as "not found".
            print(f'{label} SNID {snid} not found')
        else:
            tables.append(table)
    return tables


mappings_bulla = pd.read_csv('test_data/MDF_VS_KN-KilonovaSims/Bulla/SNID_TO_SKYMAP.csv')
mappings_kasen = pd.read_csv('test_data/MDF_VS_KN-KilonovaSims/Kasen/SNID_TO_SKYMAP.csv')

# Bulla tables first, then Kasen, so the concatenated row order is stable.
frames = _load_feature_tables(mappings_bulla, 'Bulla_features', 'Bulla')
frames += _load_feature_tables(mappings_kasen, 'Kasen_features', 'Kasen')
df = pd.concat(frames)

# Removing duplicate SNIDs. They occur because a flare can fall inside the
# sky maps of more than one kilonova.
df = df.drop_duplicates(subset=['SNID'])
# Optional visual check of the sky positions (left disabled):
# c = SkyCoord(ra=df['RA'], dec=df['DEC'], frame=ICRS, unit='deg')
# plotGenricSkyMap(c)

# Flare rows (CLASS == 'MDF').
df_1 = df.loc[df['CLASS'] == 'MDF']

# Kilonova rows, one frame per simulation family.
df_Bulla = df.loc[df['CLASS'] == 'KN Bulla']
df_Kasen = df.loc[df['CLASS'] == 'KN Kasen']
print(len(df_Bulla), len(df_Kasen))

# All kilonova rows, Bulla first and Kasen second.
df_2 = pd.concat([df_Bulla, df_Kasen])

# Keep at most as many flare rows as there are kilonova rows.
df_1 = df_1.head(len(df_2))

# Merge the dataframes: flares first, then kilonovae.
df = pd.concat([df_1, df_2])
# Build the feature matrix: one-hot encoded passband columns followed by the
# numeric columns listed below.
passbands = df[['BAND', 'PRE-BAND', 'POST-BAND']]

# One hot encoding for passband features. The dense-output keyword is
# `sparse_output` in current scikit-learn; older releases only accept
# `sparse`, so fall back to it when the new name is rejected.
try:
    enc = OneHotEncoder(sparse_output=False)
except TypeError:
    enc = OneHotEncoder(sparse=False)
one_hot = enc.fit_transform(passbands)

# Numeric features appended after the one-hot block, in this order.
# 'NEXT-PHOT-FLAG' and 'NUM_DETECTIONS' were tried previously and are
# currently left out; add them to this list to re-enable them.
numeric_columns = ['TIME-TO-PREV', 'TIME-TO-NEXT', 'MDF_DENSITY', 'GW_PROB']

# Stack by column so the matrix width follows however many categories the
# encoder actually found, instead of assuming exactly 20 one-hot columns.
x_new = np.hstack([one_hot, df[numeric_columns].to_numpy(dtype=float)])

features = list(enc.get_feature_names_out()) + numeric_columns

x_new = pd.DataFrame(x_new, columns=features)
x = x_new
print(x_new)
# Binary target, one entry per row of df: 0 for either kilonova class
# ('KN Bulla' / 'KN Kasen'), 1 for everything else.
y = [
    0 if label in ('KN Bulla', 'KN Kasen') else 1
    for label in df['CLASS']
]
# Hold out 40% of the samples for testing; the seed keeps the split repeatable.
X_train, X_test, y_train, y_test = train_test_split(
    x,
    y,
    test_size=0.4,
    random_state=42,
)

# Random forest with 1000 trees and a fixed seed.
clf = RandomForestClassifier(n_estimators=1000, random_state=42)
clf.fit(X_train, y_train)

# Score the held-out set.
y_pred = clf.predict(X_test)
print("Accuracy:", metrics.accuracy_score(y_test, y_pred))

importance = clf.feature_importances_

# Confusion matrix with rows/columns ordered by the classifier's class labels.
cm = confusion_matrix(y_test, y_pred, labels=clf.classes_)
disp = ConfusionMatrixDisplay(
    confusion_matrix=cm,
    display_labels=['Kilo Nova', 'M Dwarf flare'],
)
disp.plot()
plt.show()

# One line per feature: its name and its importance in the fitted forest.
for feature_name, score in zip(features, importance):
    print(f'{feature_name}: {score}')