2025-06-07 10:45:09 +08:00
|
|
|
|
import sys
|
2025-06-13 17:14:03 +08:00
|
|
|
|
import logging
|
|
|
|
|
|
from utils.config_loader import ConfigLoader
|
2025-06-07 10:45:09 +08:00
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
|
import psycopg2
|
|
|
|
|
|
except ImportError:
|
|
|
|
|
|
psycopg2 = None
|
|
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
|
import sqlite3
|
|
|
|
|
|
except ImportError:
|
|
|
|
|
|
sqlite3 = None
|
|
|
|
|
|
|
2025-06-13 17:14:03 +08:00
|
|
|
|
try:
|
|
|
|
|
|
import mysql.connector
|
|
|
|
|
|
except ImportError:
|
|
|
|
|
|
mysql = None
|
|
|
|
|
|
|
2025-06-07 10:45:09 +08:00
|
|
|
|
|
|
|
|
|
|
class SQLUtils:
|
2025-06-13 17:14:03 +08:00
|
|
|
|
# 存储连接池,避免重复创建连接
|
|
|
|
|
|
_connection_pool = {}
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, db_type=None, source_name=None, **kwargs):
|
|
|
|
|
|
"""初始化SQLUtils对象
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
|
db_type: 数据库类型 'sqlite', 'postgresql', 'mysql' 之一,如果为None则使用配置文件中的默认数据源
|
|
|
|
|
|
source_name: 数据源名称,用于从配置中获取特定的数据源,如'sqlite', 'postgresql', 'mysql'
|
|
|
|
|
|
**kwargs: 连接参数,如果没有提供,则使用配置文件中的参数
|
|
|
|
|
|
"""
|
2025-06-07 10:45:09 +08:00
|
|
|
|
self.conn = None
|
|
|
|
|
|
self.cursor = None
|
2025-06-13 17:14:03 +08:00
|
|
|
|
|
|
|
|
|
|
# 如果指定了source_name,直接使用该名称的数据源配置
|
|
|
|
|
|
if source_name:
|
|
|
|
|
|
config_loader = ConfigLoader.get_instance()
|
|
|
|
|
|
db_config = config_loader.get_value(f'database.sources.{source_name}', {})
|
|
|
|
|
|
if not db_config:
|
|
|
|
|
|
raise ValueError(f"未找到数据源配置: {source_name}")
|
|
|
|
|
|
|
|
|
|
|
|
db_type = source_name # 数据源名称同时也是类型
|
|
|
|
|
|
|
|
|
|
|
|
if not kwargs: # 如果没有提供连接参数,则使用配置中的参数
|
|
|
|
|
|
if source_name == 'sqlite':
|
|
|
|
|
|
kwargs = {'database': db_config.get('path', 'db/jtDB.db')}
|
|
|
|
|
|
else:
|
|
|
|
|
|
kwargs = {
|
|
|
|
|
|
'host': db_config.get('host', 'localhost'),
|
|
|
|
|
|
'user': db_config.get('user', ''),
|
|
|
|
|
|
'password': db_config.get('password', ''),
|
|
|
|
|
|
'database': db_config.get('name', 'jtDB')
|
|
|
|
|
|
}
|
|
|
|
|
|
if 'port' in db_config and db_config['port']:
|
|
|
|
|
|
kwargs['port'] = int(db_config['port'])
|
|
|
|
|
|
|
|
|
|
|
|
# 如果没有指定数据库类型和数据源名称,则使用配置中的默认数据源
|
|
|
|
|
|
elif db_type is None:
|
|
|
|
|
|
config_loader = ConfigLoader.get_instance()
|
|
|
|
|
|
default_source = config_loader.get_value('database.default', 'sqlite')
|
|
|
|
|
|
|
|
|
|
|
|
# 如果没有提供连接参数,则从配置文件获取
|
|
|
|
|
|
if not kwargs:
|
|
|
|
|
|
db_config = config_loader.get_database_config(default_source)
|
|
|
|
|
|
if default_source == 'sqlite':
|
|
|
|
|
|
kwargs = {'database': db_config.get('path', 'db/jtDB.db')}
|
|
|
|
|
|
else:
|
|
|
|
|
|
kwargs = {
|
|
|
|
|
|
'host': db_config.get('host', 'localhost'),
|
|
|
|
|
|
'user': db_config.get('user', ''),
|
|
|
|
|
|
'password': db_config.get('password', ''),
|
|
|
|
|
|
'database': db_config.get('name', 'jtDB')
|
|
|
|
|
|
}
|
|
|
|
|
|
if 'port' in db_config and db_config['port']:
|
|
|
|
|
|
kwargs['port'] = int(db_config['port'])
|
|
|
|
|
|
|
|
|
|
|
|
db_type = default_source
|
|
|
|
|
|
|
|
|
|
|
|
self.db_type = db_type.lower()
|
2025-06-07 10:45:09 +08:00
|
|
|
|
self.kwargs = kwargs
|
2025-06-13 17:14:03 +08:00
|
|
|
|
self.source_name = source_name or self.db_type
|
|
|
|
|
|
|
|
|
|
|
|
# 尝试从连接池获取连接,如果没有则创建新连接
|
|
|
|
|
|
self._get_connection()
|
|
|
|
|
|
|
|
|
|
|
|
def _get_connection(self):
|
|
|
|
|
|
"""从连接池获取连接,如果没有则创建新连接"""
|
|
|
|
|
|
# 创建连接键,包含数据库类型和连接参数
|
|
|
|
|
|
conn_key = f"{self.db_type}:{str(self.kwargs)}"
|
|
|
|
|
|
|
|
|
|
|
|
# 检查连接池中是否已有此连接
|
|
|
|
|
|
if conn_key in SQLUtils._connection_pool:
|
|
|
|
|
|
try:
|
|
|
|
|
|
# 尝试执行简单查询,确认连接有效
|
|
|
|
|
|
conn, cursor = SQLUtils._connection_pool[conn_key]
|
|
|
|
|
|
cursor.execute("SELECT 1")
|
|
|
|
|
|
# 连接有效,直接使用
|
|
|
|
|
|
self.conn = conn
|
|
|
|
|
|
self.cursor = cursor
|
|
|
|
|
|
return
|
|
|
|
|
|
except Exception:
|
|
|
|
|
|
# 连接已失效,从连接池移除
|
|
|
|
|
|
del SQLUtils._connection_pool[conn_key]
|
|
|
|
|
|
|
|
|
|
|
|
# 创建新连接
|
2025-06-07 10:45:09 +08:00
|
|
|
|
self.connect()
|
2025-06-13 17:14:03 +08:00
|
|
|
|
|
|
|
|
|
|
# 将新连接添加到连接池
|
|
|
|
|
|
if self.conn and self.cursor:
|
|
|
|
|
|
SQLUtils._connection_pool[conn_key] = (self.conn, self.cursor)
|
2025-06-07 10:45:09 +08:00
|
|
|
|
|
|
|
|
|
|
def connect(self):
|
2025-06-13 17:14:03 +08:00
|
|
|
|
"""连接到数据库"""
|
|
|
|
|
|
try:
|
|
|
|
|
|
if self.db_type in ['pgsql', 'postgresql']:
|
|
|
|
|
|
if not psycopg2:
|
|
|
|
|
|
raise ImportError('psycopg2 is not installed')
|
|
|
|
|
|
self.conn = psycopg2.connect(**self.kwargs)
|
|
|
|
|
|
elif self.db_type in ['sqlite', 'sqlite3']:
|
|
|
|
|
|
if not sqlite3:
|
|
|
|
|
|
raise ImportError('sqlite3 is not installed')
|
|
|
|
|
|
self.conn = sqlite3.connect(self.kwargs.get('database', ':memory:'))
|
|
|
|
|
|
elif self.db_type == 'mysql':
|
|
|
|
|
|
if not mysql:
|
|
|
|
|
|
raise ImportError('mysql.connector is not installed')
|
|
|
|
|
|
self.conn = mysql.connector.connect(**self.kwargs)
|
|
|
|
|
|
else:
|
|
|
|
|
|
raise ValueError(f'不支持的数据库类型: {self.db_type}')
|
|
|
|
|
|
|
|
|
|
|
|
self.cursor = self.conn.cursor()
|
|
|
|
|
|
logging.debug(f"成功连接到数据库: {self.db_type}")
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
|
logging.error(f"连接数据库失败: {e}")
|
|
|
|
|
|
raise
|
2025-06-07 10:45:09 +08:00
|
|
|
|
|
|
|
|
|
|
def execute_query(self, sql, params=None):
|
|
|
|
|
|
if params is None:
|
|
|
|
|
|
params = ()
|
|
|
|
|
|
self.cursor.execute(sql, params)
|
|
|
|
|
|
self.conn.commit()
|
|
|
|
|
|
|
|
|
|
|
|
def execute_update(self, sql, params=None):
|
|
|
|
|
|
try:
|
|
|
|
|
|
self.cursor.execute(sql,params)
|
|
|
|
|
|
self.conn.commit()
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
|
self.conn.rollback()
|
|
|
|
|
|
raise e
|
|
|
|
|
|
|
|
|
|
|
|
def begin_transaction(self) -> None:
|
|
|
|
|
|
"""开始事务"""
|
|
|
|
|
|
if self.db_type in ['sqlite', 'sqlite3']:
|
|
|
|
|
|
self.execute_query('BEGIN TRANSACTION')
|
|
|
|
|
|
else:
|
|
|
|
|
|
self.conn.autocommit = False
|
|
|
|
|
|
|
|
|
|
|
|
def commit_transaction(self) -> None:
|
|
|
|
|
|
"""提交事务"""
|
|
|
|
|
|
self.conn.commit()
|
|
|
|
|
|
if self.db_type not in ['sqlite', 'sqlite3']:
|
|
|
|
|
|
self.conn.autocommit = True
|
|
|
|
|
|
|
|
|
|
|
|
def rollback_transaction(self) -> None:
|
|
|
|
|
|
"""回滚事务"""
|
|
|
|
|
|
self.conn.rollback()
|
|
|
|
|
|
if self.db_type not in ['sqlite', 'sqlite3']:
|
|
|
|
|
|
self.conn.autocommit = True
|
|
|
|
|
|
|
|
|
|
|
|
def fetchone(self):
|
|
|
|
|
|
return self.cursor.fetchone()
|
|
|
|
|
|
|
|
|
|
|
|
def fetchall(self):
|
|
|
|
|
|
return self.cursor.fetchall()
|
|
|
|
|
|
|
|
|
|
|
|
def close(self):
|
2025-06-13 17:14:03 +08:00
|
|
|
|
"""关闭连接(实际上是将连接返回到连接池)"""
|
|
|
|
|
|
# 这里不再实际关闭连接,让连接池管理连接生命周期
|
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
|
def close_all_connections():
|
|
|
|
|
|
"""关闭所有连接池中的连接"""
|
|
|
|
|
|
for conn, cursor in SQLUtils._connection_pool.values():
|
|
|
|
|
|
try:
|
|
|
|
|
|
if cursor:
|
|
|
|
|
|
cursor.close()
|
|
|
|
|
|
if conn:
|
|
|
|
|
|
conn.close()
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
|
logging.error(f"关闭数据库连接失败: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
SQLUtils._connection_pool.clear()
|
|
|
|
|
|
logging.info("已关闭所有数据库连接")
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
|
def get_sqlite_connection():
|
|
|
|
|
|
"""获取SQLite连接"""
|
|
|
|
|
|
return SQLUtils(source_name='sqlite')
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
|
def get_postgresql_connection():
|
|
|
|
|
|
"""获取PostgreSQL连接"""
|
|
|
|
|
|
return SQLUtils(source_name='postgresql')
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
|
def get_mysql_connection():
|
|
|
|
|
|
"""获取MySQL连接"""
|
|
|
|
|
|
return SQLUtils(source_name='mysql')
|