253 lines
9.5 KiB
Python
253 lines
9.5 KiB
Python
import sys
|
||
import logging
|
||
import threading
|
||
from utils.config_loader import ConfigLoader
|
||
|
||
try:
|
||
import psycopg2
|
||
except ImportError:
|
||
psycopg2 = None
|
||
|
||
try:
|
||
import sqlite3
|
||
except ImportError:
|
||
sqlite3 = None
|
||
|
||
try:
|
||
import mysql.connector
|
||
except ImportError:
|
||
mysql = None
|
||
|
||
|
||
class SQLUtils:
|
||
# 存储连接池,避免重复创建连接
|
||
_connection_pool = {}
|
||
# 添加线程锁,用于防止多线程同时使用同一个cursor
|
||
_lock = threading.RLock()
|
||
|
||
def __init__(self, db_type=None, source_name=None, **kwargs):
|
||
"""初始化SQLUtils对象
|
||
|
||
Args:
|
||
db_type: 数据库类型 'sqlite', 'postgresql', 'mysql' 之一,如果为None则使用配置文件中的默认数据源
|
||
source_name: 数据源名称,用于从配置中获取特定的数据源,如'sqlite', 'postgresql', 'mysql'
|
||
**kwargs: 连接参数,如果没有提供,则使用配置文件中的参数
|
||
"""
|
||
self.conn = None
|
||
self.cursor = None
|
||
|
||
# 如果指定了source_name,直接使用该名称的数据源配置
|
||
if source_name:
|
||
config_loader = ConfigLoader.get_instance()
|
||
db_config = config_loader.get_value(f'database.sources.{source_name}', {})
|
||
if not db_config:
|
||
raise ValueError(f"未找到数据源配置: {source_name}")
|
||
|
||
db_type = source_name # 数据源名称同时也是类型
|
||
|
||
if not kwargs: # 如果没有提供连接参数,则使用配置中的参数
|
||
if source_name == 'sqlite':
|
||
kwargs = {'database': db_config.get('path', 'db/jtDB.db')}
|
||
else:
|
||
kwargs = {
|
||
'host': db_config.get('host', 'localhost'),
|
||
'user': db_config.get('user', ''),
|
||
'password': db_config.get('password', ''),
|
||
'database': db_config.get('name', 'jtDB')
|
||
}
|
||
if 'port' in db_config and db_config['port']:
|
||
kwargs['port'] = int(db_config['port'])
|
||
|
||
# 如果没有指定数据库类型和数据源名称,则使用配置中的默认数据源
|
||
elif db_type is None:
|
||
config_loader = ConfigLoader.get_instance()
|
||
default_source = config_loader.get_value('database.default', 'sqlite')
|
||
|
||
# 如果没有提供连接参数,则从配置文件获取
|
||
if not kwargs:
|
||
db_config = config_loader.get_database_config(default_source)
|
||
if default_source == 'sqlite':
|
||
kwargs = {'database': db_config.get('path', 'db/jtDB.db')}
|
||
else:
|
||
kwargs = {
|
||
'host': db_config.get('host', 'localhost'),
|
||
'user': db_config.get('user', ''),
|
||
'password': db_config.get('password', ''),
|
||
'database': db_config.get('name', 'jtDB')
|
||
}
|
||
if 'port' in db_config and db_config['port']:
|
||
kwargs['port'] = int(db_config['port'])
|
||
|
||
db_type = default_source
|
||
|
||
self.db_type = db_type.lower()
|
||
self.kwargs = kwargs
|
||
self.source_name = source_name or self.db_type
|
||
|
||
# 尝试从连接池获取连接,如果没有则创建新连接
|
||
self._get_connection()
|
||
|
||
def __enter__(self):
|
||
"""上下文管理器入口方法,支持with语句"""
|
||
return self
|
||
|
||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||
"""上下文管理器退出方法,自动关闭游标"""
|
||
self.close()
|
||
return False # 不抑制异常
|
||
|
||
def _get_connection(self):
|
||
"""从连接池获取连接,如果没有则创建新连接"""
|
||
# 创建连接键,包含数据库类型和连接参数
|
||
conn_key = f"{self.db_type}:{str(self.kwargs)}"
|
||
|
||
# 检查连接池中是否已有此连接
|
||
with SQLUtils._lock:
|
||
if conn_key in SQLUtils._connection_pool:
|
||
try:
|
||
# 尝试执行简单查询,确认连接有效
|
||
conn, cursor = SQLUtils._connection_pool[conn_key]
|
||
cursor.execute("SELECT 1")
|
||
# 连接有效,直接使用
|
||
self.conn = conn
|
||
self.cursor = cursor
|
||
return
|
||
except Exception:
|
||
# 连接已失效,从连接池移除
|
||
del SQLUtils._connection_pool[conn_key]
|
||
|
||
# 创建新连接
|
||
self.connect()
|
||
|
||
# 将新连接添加到连接池
|
||
if self.conn and self.cursor:
|
||
SQLUtils._connection_pool[conn_key] = (self.conn, self.cursor)
|
||
|
||
def connect(self):
|
||
"""连接到数据库"""
|
||
try:
|
||
if self.db_type in ['pgsql', 'postgresql']:
|
||
if not psycopg2:
|
||
raise ImportError('psycopg2 is not installed')
|
||
self.conn = psycopg2.connect(**self.kwargs)
|
||
elif self.db_type in ['sqlite', 'sqlite3']:
|
||
if not sqlite3:
|
||
raise ImportError('sqlite3 is not installed')
|
||
self.conn = sqlite3.connect(self.kwargs.get('database', ':memory:'), check_same_thread=False)
|
||
elif self.db_type == 'mysql':
|
||
if not mysql:
|
||
raise ImportError('mysql.connector is not installed')
|
||
self.conn = mysql.connector.connect(**self.kwargs)
|
||
else:
|
||
raise ValueError(f'不支持的数据库类型: {self.db_type}')
|
||
|
||
self.cursor = self.conn.cursor()
|
||
logging.debug(f"成功连接到数据库: {self.db_type}")
|
||
except Exception as e:
|
||
logging.error(f"连接数据库失败: {e}")
|
||
raise
|
||
|
||
def execute_query(self, sql, params=None):
|
||
if params is None:
|
||
params = ()
|
||
with SQLUtils._lock:
|
||
try:
|
||
self.cursor.execute(sql, params)
|
||
return self.cursor
|
||
except Exception as e:
|
||
logging.error(f"执行查询失败: {e}, SQL: {sql}, 参数: {params}")
|
||
raise
|
||
|
||
def execute_update(self, sql, params=None):
|
||
try:
|
||
with SQLUtils._lock:
|
||
self.cursor.execute(sql, params)
|
||
self.conn.commit()
|
||
return self.cursor.rowcount
|
||
except Exception as e:
|
||
self.conn.rollback()
|
||
logging.error(f"执行更新失败: {e}, SQL: {sql}, 参数: {params}")
|
||
raise e
|
||
|
||
def begin_transaction(self) -> None:
|
||
"""开始事务"""
|
||
with SQLUtils._lock:
|
||
if self.db_type in ['sqlite', 'sqlite3']:
|
||
self.execute_query('BEGIN TRANSACTION')
|
||
else:
|
||
self.conn.autocommit = False
|
||
|
||
def commit_transaction(self) -> None:
|
||
"""提交事务"""
|
||
with SQLUtils._lock:
|
||
self.conn.commit()
|
||
if self.db_type not in ['sqlite', 'sqlite3']:
|
||
self.conn.autocommit = True
|
||
|
||
def rollback_transaction(self) -> None:
|
||
"""回滚事务"""
|
||
with SQLUtils._lock:
|
||
self.conn.rollback()
|
||
if self.db_type not in ['sqlite', 'sqlite3']:
|
||
self.conn.autocommit = True
|
||
|
||
def fetchone(self):
|
||
with SQLUtils._lock:
|
||
return self.cursor.fetchone()
|
||
|
||
def fetchall(self):
|
||
with SQLUtils._lock:
|
||
return self.cursor.fetchall()
|
||
|
||
def close(self):
|
||
"""关闭当前游标,但保留连接在连接池中"""
|
||
# 不再关闭连接,只关闭游标,减少重复创建连接的开销
|
||
# 连接仍然保留在连接池中供后续使用
|
||
pass
|
||
|
||
def real_close(self):
|
||
"""真正关闭连接,从连接池中移除"""
|
||
conn_key = f"{self.db_type}:{str(self.kwargs)}"
|
||
with SQLUtils._lock:
|
||
if conn_key in SQLUtils._connection_pool:
|
||
try:
|
||
conn, cursor = SQLUtils._connection_pool[conn_key]
|
||
if cursor:
|
||
cursor.close()
|
||
if conn:
|
||
conn.close()
|
||
del SQLUtils._connection_pool[conn_key]
|
||
logging.debug(f"已关闭并移除连接: {conn_key}")
|
||
except Exception as e:
|
||
logging.error(f"关闭连接失败: {e}")
|
||
|
||
@staticmethod
|
||
def close_all_connections():
|
||
"""关闭所有连接池中的连接"""
|
||
with SQLUtils._lock:
|
||
for conn, cursor in SQLUtils._connection_pool.values():
|
||
try:
|
||
if cursor:
|
||
cursor.close()
|
||
if conn:
|
||
conn.close()
|
||
except Exception as e:
|
||
logging.error(f"关闭数据库连接失败: {e}")
|
||
|
||
SQLUtils._connection_pool.clear()
|
||
logging.info("已关闭所有数据库连接")
|
||
|
||
@staticmethod
|
||
def get_sqlite_connection():
|
||
"""获取SQLite连接"""
|
||
return SQLUtils(source_name='sqlite')
|
||
|
||
@staticmethod
|
||
def get_postgresql_connection():
|
||
"""获取PostgreSQL连接"""
|
||
return SQLUtils(source_name='postgresql')
|
||
|
||
@staticmethod
|
||
def get_mysql_connection():
|
||
"""获取MySQL连接"""
|
||
return SQLUtils(source_name='mysql') |