-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgenerate_datefile.py
More file actions
122 lines (96 loc) · 3.5 KB
/
generate_datefile.py
File metadata and controls
122 lines (96 loc) · 3.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import argparse
import sqlite3
import os
import csv
from typing import List, Dict
import time
def parse_args(argv=None):
    """Parse the command-line options of this CSV -> SQLite loader.

    Options (both required):
        --csv_folder: folder containing the CSV files to import.
        --output_db:  path of the SQLite database file to write.

    Args:
        argv: optional list of argument strings. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, exactly like a plain
            ``parse_args()`` call; passing a list allows programmatic use.

    Returns:
        argparse.Namespace with ``csv_folder`` and ``output_db`` attributes.

    Raises:
        SystemExit: raised by argparse when a required option is missing
            or ``--help`` is requested.
    """
    # The description previously claimed the tool generates "datefile.txt";
    # it actually loads CSV files into a SQLite database.
    parser = argparse.ArgumentParser(
        description="Load every CSV file under a folder into a SQLite database"
    )
    parser.add_argument("--csv_folder", required=True, type=str, help="CSV folder path")
    parser.add_argument(
        "--output_db", required=True, type=str, help="Output database file path"
    )
    return parser.parse_args(argv)
def scan_csv_files(csv_folder_path: str) -> Dict[str, str]:
    """Recursively collect the ``.csv`` files below ``csv_folder_path``.

    Only names ending in ``.csv`` (case-sensitive) are kept. Directories
    and files are visited in sorted order, so the result, and the order in
    which callers iterate it, is the same on every run and platform.

    Because the result is keyed by bare file name, two files with the same
    name in different sub-folders collide. The one visited last (in sorted
    walk order) is kept, and a warning naming both paths goes to stderr so
    the collision is not silent.

    A folder that does not exist yields an empty dict: ``os.walk`` ignores
    errors by default.

    Args:
        csv_folder_path: root folder to scan.

    Returns:
        Mapping of CSV file name (e.g. ``"sales.csv"``) -> full path.
    """
    import sys

    csv_filename_path_dict: Dict[str, str] = {}
    for root, dirs, files in os.walk(csv_folder_path):
        # Sorting ``dirs`` in place controls the order os.walk descends in.
        dirs.sort()
        for file in sorted(files):
            if not file.endswith(".csv"):
                continue
            path = os.path.join(root, file)
            previous = csv_filename_path_dict.get(file)
            if previous is not None:
                print(
                    f"Warning: duplicate CSV file name {file!r}: "
                    f"{previous} is replaced by {path}",
                    file=sys.stderr,
                )
            csv_filename_path_dict[file] = path
    return csv_filename_path_dict
def get_table_name(file_name: str) -> str:
    """Derive a database table name from a file name.

    Only the final extension is removed; everything before it is kept
    verbatim. For example, ``"sales.2024.csv"`` becomes ``"sales.2024"``.
    Names that start with a dot and have no other dot (e.g. ``".csv"``)
    are returned unchanged, following ``os.path.splitext``. The result is
    not sanitized for SQL.

    Args:
        file_name: a file name, normally including its extension.

    Returns:
        ``file_name`` with its last extension stripped.
    """
    stem, _extension = os.path.splitext(file_name)
    return stem
def create_table(conn: sqlite3.Connection, table_name: str, columns: List[str]):
    """Create ``table_name`` (if absent) with one TEXT column per name.

    The table name and every column name are emitted as double-quoted SQL
    identifiers, with any embedded ``"`` doubled. Names taken from CSV
    headers or file names may contain spaces, dashes or dots, or be SQL
    keywords (``order``, ``group``); unquoted, these produce invalid or
    unintended SQL. The generated statement is printed before execution
    and committed afterwards.

    Args:
        conn: open SQLite connection.
        table_name: name of the table to create.
        columns: column names, in order; each becomes a TEXT column.

    Raises:
        sqlite3.Error: if SQLite rejects the statement (e.g. empty or
            duplicate column names).
    """

    def quote(identifier: str) -> str:
        # SQL standard identifier quoting: wrap in "..." and double any ".
        return '"' + identifier.replace('"', '""') + '"'

    sql = "Create Table if not exists {} ({})".format(
        quote(table_name), ", ".join(f"{quote(col)} TEXT" for col in columns)
    )
    print("\t", sql)
    cursor = conn.cursor()
    try:
        cursor.execute(sql)
        conn.commit()
    finally:
        # Release the cursor even when execute() raises.
        cursor.close()
def write_csv_data_to_db(
    conn: sqlite3.Connection,
    table_name: str,
    csv_file_path: str,
    batch_size: int = 1000,
):
    """Load one CSV file into ``table_name``, creating the table if needed.

    The first CSV row is the header. After stripping whitespace and stray
    double quotes, its cells are passed to ``create_table`` as column
    names. Every following non-blank row is stripped the same way and
    inserted with a parameterized ``INSERT``. Rows are sent in batches of
    ``batch_size``, with a commit after each batch. A row count and the
    elapsed time are printed at the end.

    Args:
        conn: open SQLite connection to write into.
        table_name: destination table.
        csv_file_path: path of the CSV file to read (platform default
            encoding).
        batch_size: rows per ``executemany``/commit. Larger batches mean
            fewer commits, which dominate the cost of the import. Values
            below 1 behave like 1.

    Returns:
        True when the file was loaded. False when it has no header row
        (empty file or blank first line); nothing is written in that case.

    Raises:
        sqlite3.ProgrammingError: if a data row's field count differs from
            the header's (wrong number of SQL bindings).
    """
    start_time = time.perf_counter()
    record_num = 0
    # newline="" is required by the csv module so that csv.reader itself
    # handles \r\n endings and line breaks embedded in quoted fields.
    with open(csv_file_path, "r", newline="") as f:
        reader = csv.reader(f)
        header_row = next(reader, None)
        # None -> empty file; [] -> blank first line. Either way there are
        # no column names, so there is no table to create.
        if not header_row:
            print("\t", table_name, "skipped: no header row in", csv_file_path)
            return False
        headers = [col.strip().strip('"') for col in header_row]
        create_table(conn, table_name, headers)

        # Quote the table name (doubling embedded quotes) so names derived
        # from file names with spaces, dots or dashes stay valid SQL.
        quoted_table = '"' + table_name.replace('"', '""') + '"'
        placeholders = ", ".join(["?"] * len(headers))
        sql = f"INSERT INTO {quoted_table} VALUES ({placeholders})"

        rows_cache = []
        cursor = conn.cursor()
        try:
            for row in reader:
                if not row:
                    # csv.reader yields [] for a blank line; inserting it
                    # would fail with a binding-count error.
                    continue
                # csv.reader already removes CSV quoting. The extra
                # strip('"') presumably cleans quotes it leaves in place
                # (e.g. after a leading space: `a, "b"`). Note that it also
                # removes legitimate leading/trailing quotes from values.
                rows_cache.append([x.strip().strip('"') for x in row])
                record_num += 1
                if len(rows_cache) >= batch_size:
                    cursor.executemany(sql, rows_cache)
                    conn.commit()
                    rows_cache.clear()
            if rows_cache:
                cursor.executemany(sql, rows_cache)
                conn.commit()
        finally:
            # Release the cursor even if a row fails to insert.
            cursor.close()
    print(
        "\t",
        table_name,
        "insert",
        record_num,
        "rows, cost:",
        time.perf_counter() - start_time,
        "s",
    )
    return True
if __name__ == "__main__":
    cli_args = parse_args()
    # Open (or create) the target SQLite database file.
    db = sqlite3.connect(cli_args.output_db)
    # Speed over durability: no fsync after writes, and the rollback journal
    # lives in memory. A crash or power loss mid-import can therefore
    # corrupt the database file. Presumably acceptable because it can be
    # regenerated from the CSVs.
    db.execute("PRAGMA synchronous = OFF")
    db.execute("PRAGMA journal_mode = MEMORY")
    csv_files = scan_csv_files(cli_args.csv_folder)
    # One table per CSV file, named after the file without its extension.
    for csv_name, csv_path in csv_files.items():
        target_table = get_table_name(csv_name)
        print(f"Creating and Writing data to table: {target_table} ({csv_name})")
        write_csv_data_to_db(db, target_table, csv_path)
    db.close()
    print("\nAll done!")