jiateng_ws/utils/sql_utils.py

253 lines
9.5 KiB
Python
Raw Normal View History

2025-06-07 10:45:09 +08:00
import sys
2025-06-13 17:14:03 +08:00
import logging
2025-07-01 09:36:16 +08:00
import threading
2025-06-13 17:14:03 +08:00
from utils.config_loader import ConfigLoader
2025-06-07 10:45:09 +08:00
try:
import psycopg2
except ImportError:
psycopg2 = None
try:
import sqlite3
except ImportError:
sqlite3 = None
2025-06-13 17:14:03 +08:00
try:
import mysql.connector
except ImportError:
mysql = None
2025-06-07 10:45:09 +08:00
class SQLUtils:
2025-06-13 17:14:03 +08:00
# 存储连接池,避免重复创建连接
_connection_pool = {}
2025-07-01 09:36:16 +08:00
# 添加线程锁用于防止多线程同时使用同一个cursor
_lock = threading.RLock()
2025-06-13 17:14:03 +08:00
def __init__(self, db_type=None, source_name=None, **kwargs):
"""初始化SQLUtils对象
Args:
db_type: 数据库类型 'sqlite', 'postgresql', 'mysql' 之一如果为None则使用配置文件中的默认数据源
source_name: 数据源名称用于从配置中获取特定的数据源'sqlite', 'postgresql', 'mysql'
**kwargs: 连接参数如果没有提供则使用配置文件中的参数
"""
2025-06-07 10:45:09 +08:00
self.conn = None
self.cursor = None
2025-06-13 17:14:03 +08:00
# 如果指定了source_name直接使用该名称的数据源配置
if source_name:
config_loader = ConfigLoader.get_instance()
db_config = config_loader.get_value(f'database.sources.{source_name}', {})
if not db_config:
raise ValueError(f"未找到数据源配置: {source_name}")
db_type = source_name # 数据源名称同时也是类型
if not kwargs: # 如果没有提供连接参数,则使用配置中的参数
if source_name == 'sqlite':
kwargs = {'database': db_config.get('path', 'db/jtDB.db')}
else:
kwargs = {
'host': db_config.get('host', 'localhost'),
'user': db_config.get('user', ''),
'password': db_config.get('password', ''),
'database': db_config.get('name', 'jtDB')
}
if 'port' in db_config and db_config['port']:
kwargs['port'] = int(db_config['port'])
# 如果没有指定数据库类型和数据源名称,则使用配置中的默认数据源
elif db_type is None:
config_loader = ConfigLoader.get_instance()
default_source = config_loader.get_value('database.default', 'sqlite')
# 如果没有提供连接参数,则从配置文件获取
if not kwargs:
db_config = config_loader.get_database_config(default_source)
if default_source == 'sqlite':
kwargs = {'database': db_config.get('path', 'db/jtDB.db')}
else:
kwargs = {
'host': db_config.get('host', 'localhost'),
'user': db_config.get('user', ''),
'password': db_config.get('password', ''),
'database': db_config.get('name', 'jtDB')
}
if 'port' in db_config and db_config['port']:
kwargs['port'] = int(db_config['port'])
db_type = default_source
self.db_type = db_type.lower()
2025-06-07 10:45:09 +08:00
self.kwargs = kwargs
2025-06-13 17:14:03 +08:00
self.source_name = source_name or self.db_type
# 尝试从连接池获取连接,如果没有则创建新连接
self._get_connection()
2025-07-01 09:36:16 +08:00
def __enter__(self):
"""上下文管理器入口方法支持with语句"""
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""上下文管理器退出方法,自动关闭游标"""
self.close()
return False # 不抑制异常
2025-06-13 17:14:03 +08:00
def _get_connection(self):
"""从连接池获取连接,如果没有则创建新连接"""
# 创建连接键,包含数据库类型和连接参数
conn_key = f"{self.db_type}:{str(self.kwargs)}"
# 检查连接池中是否已有此连接
2025-07-01 09:36:16 +08:00
with SQLUtils._lock:
if conn_key in SQLUtils._connection_pool:
try:
# 尝试执行简单查询,确认连接有效
conn, cursor = SQLUtils._connection_pool[conn_key]
cursor.execute("SELECT 1")
# 连接有效,直接使用
self.conn = conn
self.cursor = cursor
return
except Exception:
# 连接已失效,从连接池移除
del SQLUtils._connection_pool[conn_key]
# 创建新连接
self.connect()
# 将新连接添加到连接池
if self.conn and self.cursor:
SQLUtils._connection_pool[conn_key] = (self.conn, self.cursor)
2025-06-07 10:45:09 +08:00
def connect(self):
2025-06-13 17:14:03 +08:00
"""连接到数据库"""
try:
if self.db_type in ['pgsql', 'postgresql']:
if not psycopg2:
raise ImportError('psycopg2 is not installed')
self.conn = psycopg2.connect(**self.kwargs)
elif self.db_type in ['sqlite', 'sqlite3']:
if not sqlite3:
raise ImportError('sqlite3 is not installed')
2025-07-01 09:36:16 +08:00
self.conn = sqlite3.connect(self.kwargs.get('database', ':memory:'), check_same_thread=False)
2025-06-13 17:14:03 +08:00
elif self.db_type == 'mysql':
if not mysql:
raise ImportError('mysql.connector is not installed')
self.conn = mysql.connector.connect(**self.kwargs)
else:
raise ValueError(f'不支持的数据库类型: {self.db_type}')
self.cursor = self.conn.cursor()
logging.debug(f"成功连接到数据库: {self.db_type}")
except Exception as e:
logging.error(f"连接数据库失败: {e}")
raise
2025-06-07 10:45:09 +08:00
def execute_query(self, sql, params=None):
if params is None:
params = ()
2025-07-01 09:36:16 +08:00
with SQLUtils._lock:
try:
self.cursor.execute(sql, params)
return self.cursor
except Exception as e:
logging.error(f"执行查询失败: {e}, SQL: {sql}, 参数: {params}")
raise
2025-06-07 10:45:09 +08:00
def execute_update(self, sql, params=None):
try:
2025-07-01 09:36:16 +08:00
with SQLUtils._lock:
self.cursor.execute(sql, params)
self.conn.commit()
return self.cursor.rowcount
2025-06-07 10:45:09 +08:00
except Exception as e:
self.conn.rollback()
2025-07-01 09:36:16 +08:00
logging.error(f"执行更新失败: {e}, SQL: {sql}, 参数: {params}")
2025-06-07 10:45:09 +08:00
raise e
def begin_transaction(self) -> None:
"""开始事务"""
2025-07-01 09:36:16 +08:00
with SQLUtils._lock:
if self.db_type in ['sqlite', 'sqlite3']:
self.execute_query('BEGIN TRANSACTION')
else:
self.conn.autocommit = False
2025-06-07 10:45:09 +08:00
def commit_transaction(self) -> None:
"""提交事务"""
2025-07-01 09:36:16 +08:00
with SQLUtils._lock:
self.conn.commit()
if self.db_type not in ['sqlite', 'sqlite3']:
self.conn.autocommit = True
2025-06-07 10:45:09 +08:00
def rollback_transaction(self) -> None:
"""回滚事务"""
2025-07-01 09:36:16 +08:00
with SQLUtils._lock:
self.conn.rollback()
if self.db_type not in ['sqlite', 'sqlite3']:
self.conn.autocommit = True
2025-06-07 10:45:09 +08:00
def fetchone(self):
2025-07-01 09:36:16 +08:00
with SQLUtils._lock:
return self.cursor.fetchone()
2025-06-07 10:45:09 +08:00
def fetchall(self):
2025-07-01 09:36:16 +08:00
with SQLUtils._lock:
return self.cursor.fetchall()
2025-06-07 10:45:09 +08:00
def close(self):
2025-07-01 09:36:16 +08:00
"""关闭当前游标,但保留连接在连接池中"""
# 不再关闭连接,只关闭游标,减少重复创建连接的开销
# 连接仍然保留在连接池中供后续使用
2025-06-13 17:14:03 +08:00
pass
2025-07-01 09:36:16 +08:00
def real_close(self):
"""真正关闭连接,从连接池中移除"""
conn_key = f"{self.db_type}:{str(self.kwargs)}"
with SQLUtils._lock:
if conn_key in SQLUtils._connection_pool:
try:
conn, cursor = SQLUtils._connection_pool[conn_key]
if cursor:
cursor.close()
if conn:
conn.close()
del SQLUtils._connection_pool[conn_key]
logging.debug(f"已关闭并移除连接: {conn_key}")
except Exception as e:
logging.error(f"关闭连接失败: {e}")
2025-06-13 17:14:03 +08:00
@staticmethod
def close_all_connections():
"""关闭所有连接池中的连接"""
2025-07-01 09:36:16 +08:00
with SQLUtils._lock:
for conn, cursor in SQLUtils._connection_pool.values():
try:
if cursor:
cursor.close()
if conn:
conn.close()
except Exception as e:
logging.error(f"关闭数据库连接失败: {e}")
SQLUtils._connection_pool.clear()
logging.info("已关闭所有数据库连接")
2025-06-13 17:14:03 +08:00
@staticmethod
def get_sqlite_connection():
"""获取SQLite连接"""
return SQLUtils(source_name='sqlite')
@staticmethod
def get_postgresql_connection():
"""获取PostgreSQL连接"""
return SQLUtils(source_name='postgresql')
@staticmethod
def get_mysql_connection():
"""获取MySQL连接"""
return SQLUtils(source_name='mysql')