-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcelestial_classification.py
More file actions
89 lines (70 loc) · 2.33 KB
/
celestial_classification.py
File metadata and controls
89 lines (70 loc) · 2.33 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
# Importing the libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, confusion_matrix
# Celestial object classification pipeline:
# load the photometric CSV, train a baseline random forest on the u/g/r/i/z
# magnitudes, tune it with a cross-validated grid search, evaluate both models
# on a held-out test set, and plot the tuned model's confusion matrix and
# feature importances.

# Input file and modelling configuration.
DATA_PATH = 'star_classification_1.csv'
FEATURES = ['u', 'g', 'r', 'i', 'z']
TARGET = 'class'
TEST_SIZE = 0.3
RANDOM_STATE = 42

# Hyperparameter grid explored by GridSearchCV (5-fold CV).
PARAM_GRID = {
    'n_estimators': [100, 200, 300],
    'max_depth': [10, 20, 30],
    'min_samples_split': [2, 5, 10],
}


def load_data(path=DATA_PATH):
    """Read the dataset at *path*, print an overview, and drop unusable rows.

    Only rows with a missing value in one of FEATURES or TARGET are dropped;
    NaNs in columns the model never reads do not discard otherwise good data.

    Returns the cleaned DataFrame.
    """
    df = pd.read_csv(path)
    print(df.head())
    # DataFrame.info() prints its report itself and returns None, so it is
    # called directly rather than wrapped in print() (which would emit "None").
    df.info()
    print(df.isnull().sum())
    return df.dropna(subset=FEATURES + [TARGET])


def encode_target(labels):
    """Encode class labels as integers.

    Returns ``(codes, class_names)`` where ``class_names[k]`` is the label
    encoded as ``k``. Codes follow order of first appearance (pd.factorize),
    so the names are guaranteed to line up with the codes.
    """
    codes, uniques = pd.factorize(labels)
    return codes, list(uniques)


def evaluate(model, X_test, y_test, class_names):
    """Print a classification report and confusion matrix for *model*.

    Passing explicit ``labels`` keeps the matrix k x k in code order even if a
    class is missing from the test set or predictions, so rows/columns always
    correspond to ``class_names``.

    Returns the confusion matrix.
    """
    y_pred = model.predict(X_test)
    labels = list(range(len(class_names)))
    print(classification_report(
        y_test, y_pred, labels=labels,
        target_names=[str(name) for name in class_names]))
    conf_matrix = confusion_matrix(y_test, y_pred, labels=labels)
    print(conf_matrix)
    return conf_matrix


def tune(X_train, y_train):
    """Grid-search PARAM_GRID with 5-fold CV and return the refit best model."""
    grid_search = GridSearchCV(
        estimator=RandomForestClassifier(random_state=RANDOM_STATE),
        param_grid=PARAM_GRID,
        cv=5,
        n_jobs=-1,  # parallelism only; does not change the selected model
    )
    grid_search.fit(X_train, y_train)
    print(f'Best Parameters: {grid_search.best_params_}')
    # With the default refit=True, best_estimator_ is retrained on all of
    # X_train using the best parameters.
    return grid_search.best_estimator_


def plot_confusion_matrix(conf_matrix, class_names):
    """Show *conf_matrix* as an annotated heatmap labelled with *class_names*."""
    plt.figure(figsize=(10, 7))
    sns.heatmap(conf_matrix, annot=True, fmt="d", cmap="Blues",
                xticklabels=class_names, yticklabels=class_names)
    plt.title('Confusion Matrix')
    plt.xlabel('Predicted')
    plt.ylabel('Actual')
    plt.show()


def plot_feature_importance(model, feature_names):
    """Show a horizontal bar chart of a fitted forest's feature importances."""
    plt.figure(figsize=(10, 6))
    plt.barh(feature_names, model.feature_importances_, color='skyblue')
    plt.xlabel('Feature Importance')
    plt.ylabel('Features')
    plt.title('Feature Importance in Celestial Object Classification')
    plt.show()


def main(path=DATA_PATH):
    """Run the full load / train / tune / evaluate / plot pipeline."""
    df = load_data(path)
    X = df[FEATURES]
    y, class_names = encode_target(df[TARGET])

    # Stratify so the train and test sets keep the same class proportions.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=TEST_SIZE, random_state=RANDOM_STATE, stratify=y)

    # Baseline: default hyperparameters, seeded for reproducibility.
    model = RandomForestClassifier(random_state=RANDOM_STATE)
    model.fit(X_train, y_train)
    print('Baseline model:')
    evaluate(model, X_test, y_test, class_names)

    # The tuned model is the one actually evaluated and visualised below.
    best_model = tune(X_train, y_train)
    print('Tuned model:')
    conf_matrix = evaluate(best_model, X_test, y_test, class_names)

    plot_confusion_matrix(conf_matrix, class_names)
    plot_feature_importance(best_model, FEATURES)


if __name__ == "__main__":
    main()