jiateng_ws/utils/sql_utils.py
2025-07-01 15:32:40 +08:00

444 lines
17 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import sys
import logging
import threading
import time
from utils.config_loader import ConfigLoader
try:
import psycopg2
except ImportError:
psycopg2 = None
try:
import sqlite3
except ImportError:
sqlite3 = None
try:
import mysql.connector
except ImportError:
mysql = None
class SQLUtils:
# 存储连接池使用线程ID作为键的一部分
_connection_pool = {}
# 连接引用计数
_connection_refs = {}
# 最后使用时间记录
_last_used = {}
# 轻量级锁,仅用于连接池访问
_pool_lock = threading.RLock()
# 空闲连接超时时间(秒)
_idle_timeout = 300 # 5分钟
# 初始化清理线程标志
_cleanup_thread_started = False
@classmethod
def _start_cleanup_thread(cls):
"""启动清理空闲连接的后台线程"""
if cls._cleanup_thread_started:
return
def cleanup_idle_connections():
"""定期清理空闲连接的线程函数"""
logging.info("数据库连接清理线程已启动")
while True:
time.sleep(60) # 每分钟检查一次
try:
current_time = time.time()
with cls._pool_lock:
# 复制键列表,避免在迭代过程中修改字典
conn_keys = list(cls._connection_pool.keys())
for conn_key in conn_keys:
# 检查引用计数和最后使用时间
if (conn_key in cls._connection_refs and
cls._connection_refs[conn_key] <= 0 and
conn_key in cls._last_used and
current_time - cls._last_used[conn_key] > cls._idle_timeout):
try:
# 获取连接和游标
conn, cursor = cls._connection_pool[conn_key]
# 关闭资源
if cursor:
cursor.close()
if conn:
conn.close()
# 从所有集合中移除
cls._connection_pool.pop(conn_key, None)
cls._connection_refs.pop(conn_key, None)
cls._last_used.pop(conn_key, None)
logging.debug(f"已清理空闲连接: {conn_key}")
except Exception as e:
logging.error(f"清理空闲连接时出错: {e}")
except Exception as e:
logging.error(f"连接清理线程执行异常: {e}")
# 创建并启动后台线程
cleanup_thread = threading.Thread(
target=cleanup_idle_connections,
daemon=True,
name="DB-Connection-Cleanup"
)
cleanup_thread.start()
cls._cleanup_thread_started = True
def __init__(self, db_type=None, source_name=None, **kwargs):
"""初始化SQLUtils对象
Args:
db_type: 数据库类型 'sqlite', 'postgresql', 'mysql' 之一如果为None则使用配置文件中的默认数据源
source_name: 数据源名称,用于从配置中获取特定的数据源,如'sqlite', 'postgresql', 'mysql'
**kwargs: 连接参数,如果没有提供,则使用配置文件中的参数
"""
# 确保清理线程已启动
if not SQLUtils._cleanup_thread_started:
SQLUtils._start_cleanup_thread()
self.conn = None
self.cursor = None
# 如果指定了source_name直接使用该名称的数据源配置
if source_name:
config_loader = ConfigLoader.get_instance()
db_config = config_loader.get_value(f'database.sources.{source_name}', {})
if not db_config:
raise ValueError(f"未找到数据源配置: {source_name}")
db_type = source_name # 数据源名称同时也是类型
if not kwargs: # 如果没有提供连接参数,则使用配置中的参数
if source_name == 'sqlite':
kwargs = {'database': db_config.get('path', 'db/jtDB.db')}
else:
kwargs = {
'host': db_config.get('host', 'localhost'),
'user': db_config.get('user', ''),
'password': db_config.get('password', ''),
'database': db_config.get('name', 'jtDB')
}
if 'port' in db_config and db_config['port']:
kwargs['port'] = int(db_config['port'])
# 如果没有指定数据库类型和数据源名称,则使用配置中的默认数据源
elif db_type is None:
config_loader = ConfigLoader.get_instance()
default_source = config_loader.get_value('database.default', 'sqlite')
# 如果没有提供连接参数,则从配置文件获取
if not kwargs:
db_config = config_loader.get_database_config(default_source)
if default_source == 'sqlite':
kwargs = {'database': db_config.get('path', 'db/jtDB.db')}
else:
kwargs = {
'host': db_config.get('host', 'localhost'),
'user': db_config.get('user', ''),
'password': db_config.get('password', ''),
'database': db_config.get('name', 'jtDB')
}
if 'port' in db_config and db_config['port']:
kwargs['port'] = int(db_config['port'])
db_type = default_source
self.db_type = db_type.lower()
self.kwargs = kwargs
self.source_name = source_name or self.db_type
# 尝试从连接池获取连接,如果没有则创建新连接
self._get_connection()
def __enter__(self):
"""上下文管理器入口方法支持with语句"""
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""上下文管理器退出方法,自动关闭游标"""
self.close()
return False # 不抑制异常
def _get_connection(self):
"""从连接池获取连接基于线程ID"""
# 使用线程ID作为连接键的一部分
thread_id = threading.get_ident()
conn_key = f"{self.db_type}:{str(self.kwargs)}:{thread_id}"
# 检查连接池中是否已有此线程的连接
# 只在访问共享资源时使用锁
with self._pool_lock:
if conn_key in self._connection_pool:
try:
# 尝试执行简单查询,确认连接有效
conn, cursor = self._connection_pool[conn_key]
cursor.execute("SELECT 1")
# 连接有效,直接使用
self.conn = conn
self.cursor = cursor
# 更新引用计数和最后使用时间
self._connection_refs[conn_key] = self._connection_refs.get(conn_key, 0) + 1
self._last_used[conn_key] = time.time()
return
except Exception:
# 连接已失效,从连接池移除
self._cleanup_connection(conn_key)
# 创建新连接 - 这部分不需要锁
self.connect()
# 将新连接添加到连接池 - 需要锁
if self.conn and self.cursor:
with self._pool_lock:
self._connection_pool[conn_key] = (self.conn, self.cursor)
self._connection_refs[conn_key] = 1
self._last_used[conn_key] = time.time()
def _cleanup_connection(self, conn_key):
"""清理指定的连接"""
try:
if conn_key in self._connection_pool:
conn, cursor = self._connection_pool[conn_key]
if cursor:
try:
cursor.close()
except:
pass
if conn:
try:
conn.close()
except:
pass
# 移除相关引用
self._connection_pool.pop(conn_key, None)
self._connection_refs.pop(conn_key, None)
self._last_used.pop(conn_key, None)
logging.debug(f"已清理连接: {conn_key}")
except Exception as e:
logging.error(f"清理连接失败: {e}")
def connect(self):
"""连接到数据库 - 不需要全局锁"""
try:
if self.db_type in ['pgsql', 'postgresql']:
if not psycopg2:
raise ImportError('psycopg2 is not installed')
self.conn = psycopg2.connect(**self.kwargs)
elif self.db_type in ['sqlite', 'sqlite3']:
if not sqlite3:
raise ImportError('sqlite3 is not installed')
self.conn = sqlite3.connect(self.kwargs.get('database', ':memory:'), check_same_thread=False)
elif self.db_type == 'mysql':
if not mysql:
raise ImportError('mysql.connector is not installed')
self.conn = mysql.connector.connect(**self.kwargs)
else:
raise ValueError(f'不支持的数据库类型: {self.db_type}')
self.cursor = self.conn.cursor()
logging.debug(f"成功连接到数据库: {self.db_type}")
except Exception as e:
logging.error(f"连接数据库失败: {e}")
raise
def execute_query(self, sql, params=None):
"""执行查询 - 不使用全局锁,仅使用单个连接"""
if params is None:
params = ()
try:
# 直接执行查询,因为每个线程有自己的连接
self.cursor.execute(sql, params)
# 更新最后使用时间
thread_id = threading.get_ident()
conn_key = f"{self.db_type}:{str(self.kwargs)}:{thread_id}"
with self._pool_lock:
if conn_key in self._last_used:
self._last_used[conn_key] = time.time()
return self.cursor
except Exception as e:
logging.error(f"执行查询失败: {e}, SQL: {sql}, 参数: {params}")
raise
def execute_update(self, sql, params=None):
"""执行更新 - 不使用全局锁,仅使用单个连接"""
try:
if params is None:
params = ()
# 直接执行更新
self.cursor.execute(sql, params)
self.conn.commit()
# 更新最后使用时间
thread_id = threading.get_ident()
conn_key = f"{self.db_type}:{str(self.kwargs)}:{thread_id}"
with self._pool_lock:
if conn_key in self._last_used:
self._last_used[conn_key] = time.time()
return self.cursor.rowcount
except Exception as e:
self.conn.rollback()
logging.error(f"执行更新失败: {e}, SQL: {sql}, 参数: {params}")
raise e
def begin_transaction(self) -> None:
"""开始事务 - 不使用全局锁"""
if self.db_type in ['sqlite', 'sqlite3']:
self.execute_query('BEGIN TRANSACTION')
else:
self.conn.autocommit = False
def commit_transaction(self) -> None:
"""提交事务 - 不使用全局锁"""
self.conn.commit()
if self.db_type not in ['sqlite', 'sqlite3']:
self.conn.autocommit = True
def rollback_transaction(self) -> None:
"""回滚事务 - 不使用全局锁"""
self.conn.rollback()
if self.db_type not in ['sqlite', 'sqlite3']:
self.conn.autocommit = True
def fetchone(self):
"""获取一行数据 - 不使用全局锁"""
return self.cursor.fetchone()
def fetchall(self):
"""获取所有数据 - 不使用全局锁"""
return self.cursor.fetchall()
def get_new_cursor(self):
"""获取一个新的游标,用于避免游标递归使用问题
Returns:
cursor: 数据库游标对象
"""
if self.conn:
return self.conn.cursor()
return None
def close(self):
"""关闭当前游标,减少引用计数,必要时释放连接"""
thread_id = threading.get_ident()
conn_key = f"{self.db_type}:{str(self.kwargs)}:{thread_id}"
with self._pool_lock:
if conn_key in self._connection_refs:
# 减少引用计数
self._connection_refs[conn_key] -= 1
# 如果引用计数为0关闭连接并从池中移除
if self._connection_refs[conn_key] <= 0:
self._cleanup_connection(conn_key)
def real_close(self):
"""强制关闭连接,无论引用计数"""
thread_id = threading.get_ident()
conn_key = f"{self.db_type}:{str(self.kwargs)}:{thread_id}"
with self._pool_lock:
self._cleanup_connection(conn_key)
@staticmethod
def close_all_connections():
"""关闭所有连接池中的连接"""
with SQLUtils._pool_lock:
conn_keys = list(SQLUtils._connection_pool.keys())
for conn_key in conn_keys:
try:
conn, cursor = SQLUtils._connection_pool[conn_key]
if cursor:
cursor.close()
if conn:
conn.close()
except Exception as e:
logging.error(f"关闭数据库连接失败: {e}")
# 清空所有字典
SQLUtils._connection_pool.clear()
SQLUtils._connection_refs.clear()
SQLUtils._last_used.clear()
logging.info("已关闭所有数据库连接")
@staticmethod
def get_sqlite_connection():
"""获取SQLite连接"""
return SQLUtils(source_name='sqlite')
@staticmethod
def get_postgresql_connection():
"""获取PostgreSQL连接"""
return SQLUtils(source_name='postgresql')
@staticmethod
def get_mysql_connection():
"""获取MySQL连接"""
return SQLUtils(source_name='mysql')
@classmethod
def get_connection_pool_stats(cls):
"""获取连接池统计信息
Returns:
dict: 包含连接池统计信息的字典
"""
with cls._pool_lock:
stats = {
'active_connections': len(cls._connection_pool),
'connection_details': [],
'connection_count_by_type': {},
'active_threads': {},
}
# 统计不同类型连接数量
for conn_key in cls._connection_pool:
parts = conn_key.split(':')
if len(parts) > 0:
db_type = parts[0]
stats['connection_count_by_type'][db_type] = stats['connection_count_by_type'].get(db_type, 0) + 1
# 获取线程ID
if len(parts) > 2:
thread_id = parts[2]
stats['active_threads'][thread_id] = stats['active_threads'].get(thread_id, 0) + 1
# 连接详情
refs = cls._connection_refs.get(conn_key, 0)
last_used = cls._last_used.get(conn_key, 0)
idle_time = time.time() - last_used if last_used else 0
stats['connection_details'].append({
'key': conn_key,
'references': refs,
'idle_time_seconds': int(idle_time),
'is_idle': refs <= 0
})
return stats
@classmethod
def log_connection_pool_status(cls):
"""记录当前连接池状态到日志"""
stats = cls.get_connection_pool_stats()
logging.info(f"数据库连接池状态: 活动连接数={stats['active_connections']}")
# 记录每种数据库类型的连接数
for db_type, count in stats['connection_count_by_type'].items():
logging.info(f" - {db_type}: {count}个连接")
# 记录空闲连接
idle_connections = [d for d in stats['connection_details'] if d['is_idle']]
if idle_connections:
logging.info(f" - 空闲连接: {len(idle_connections)}")
for conn in idle_connections:
logging.debug(f" * {conn['key']} (空闲{conn['idle_time_seconds']}秒)")
return stats