jiateng_ws/utils/sql_utils.py
2025-07-01 09:36:16 +08:00

253 lines
9.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import sys
import logging
import threading
from utils.config_loader import ConfigLoader
try:
import psycopg2
except ImportError:
psycopg2 = None
try:
import sqlite3
except ImportError:
sqlite3 = None
try:
import mysql.connector
except ImportError:
mysql = None
class SQLUtils:
# 存储连接池,避免重复创建连接
_connection_pool = {}
# 添加线程锁用于防止多线程同时使用同一个cursor
_lock = threading.RLock()
def __init__(self, db_type=None, source_name=None, **kwargs):
"""初始化SQLUtils对象
Args:
db_type: 数据库类型 'sqlite', 'postgresql', 'mysql' 之一如果为None则使用配置文件中的默认数据源
source_name: 数据源名称,用于从配置中获取特定的数据源,如'sqlite', 'postgresql', 'mysql'
**kwargs: 连接参数,如果没有提供,则使用配置文件中的参数
"""
self.conn = None
self.cursor = None
# 如果指定了source_name直接使用该名称的数据源配置
if source_name:
config_loader = ConfigLoader.get_instance()
db_config = config_loader.get_value(f'database.sources.{source_name}', {})
if not db_config:
raise ValueError(f"未找到数据源配置: {source_name}")
db_type = source_name # 数据源名称同时也是类型
if not kwargs: # 如果没有提供连接参数,则使用配置中的参数
if source_name == 'sqlite':
kwargs = {'database': db_config.get('path', 'db/jtDB.db')}
else:
kwargs = {
'host': db_config.get('host', 'localhost'),
'user': db_config.get('user', ''),
'password': db_config.get('password', ''),
'database': db_config.get('name', 'jtDB')
}
if 'port' in db_config and db_config['port']:
kwargs['port'] = int(db_config['port'])
# 如果没有指定数据库类型和数据源名称,则使用配置中的默认数据源
elif db_type is None:
config_loader = ConfigLoader.get_instance()
default_source = config_loader.get_value('database.default', 'sqlite')
# 如果没有提供连接参数,则从配置文件获取
if not kwargs:
db_config = config_loader.get_database_config(default_source)
if default_source == 'sqlite':
kwargs = {'database': db_config.get('path', 'db/jtDB.db')}
else:
kwargs = {
'host': db_config.get('host', 'localhost'),
'user': db_config.get('user', ''),
'password': db_config.get('password', ''),
'database': db_config.get('name', 'jtDB')
}
if 'port' in db_config and db_config['port']:
kwargs['port'] = int(db_config['port'])
db_type = default_source
self.db_type = db_type.lower()
self.kwargs = kwargs
self.source_name = source_name or self.db_type
# 尝试从连接池获取连接,如果没有则创建新连接
self._get_connection()
def __enter__(self):
"""上下文管理器入口方法支持with语句"""
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""上下文管理器退出方法,自动关闭游标"""
self.close()
return False # 不抑制异常
def _get_connection(self):
"""从连接池获取连接,如果没有则创建新连接"""
# 创建连接键,包含数据库类型和连接参数
conn_key = f"{self.db_type}:{str(self.kwargs)}"
# 检查连接池中是否已有此连接
with SQLUtils._lock:
if conn_key in SQLUtils._connection_pool:
try:
# 尝试执行简单查询,确认连接有效
conn, cursor = SQLUtils._connection_pool[conn_key]
cursor.execute("SELECT 1")
# 连接有效,直接使用
self.conn = conn
self.cursor = cursor
return
except Exception:
# 连接已失效,从连接池移除
del SQLUtils._connection_pool[conn_key]
# 创建新连接
self.connect()
# 将新连接添加到连接池
if self.conn and self.cursor:
SQLUtils._connection_pool[conn_key] = (self.conn, self.cursor)
def connect(self):
"""连接到数据库"""
try:
if self.db_type in ['pgsql', 'postgresql']:
if not psycopg2:
raise ImportError('psycopg2 is not installed')
self.conn = psycopg2.connect(**self.kwargs)
elif self.db_type in ['sqlite', 'sqlite3']:
if not sqlite3:
raise ImportError('sqlite3 is not installed')
self.conn = sqlite3.connect(self.kwargs.get('database', ':memory:'), check_same_thread=False)
elif self.db_type == 'mysql':
if not mysql:
raise ImportError('mysql.connector is not installed')
self.conn = mysql.connector.connect(**self.kwargs)
else:
raise ValueError(f'不支持的数据库类型: {self.db_type}')
self.cursor = self.conn.cursor()
logging.debug(f"成功连接到数据库: {self.db_type}")
except Exception as e:
logging.error(f"连接数据库失败: {e}")
raise
def execute_query(self, sql, params=None):
if params is None:
params = ()
with SQLUtils._lock:
try:
self.cursor.execute(sql, params)
return self.cursor
except Exception as e:
logging.error(f"执行查询失败: {e}, SQL: {sql}, 参数: {params}")
raise
def execute_update(self, sql, params=None):
try:
with SQLUtils._lock:
self.cursor.execute(sql, params)
self.conn.commit()
return self.cursor.rowcount
except Exception as e:
self.conn.rollback()
logging.error(f"执行更新失败: {e}, SQL: {sql}, 参数: {params}")
raise e
def begin_transaction(self) -> None:
"""开始事务"""
with SQLUtils._lock:
if self.db_type in ['sqlite', 'sqlite3']:
self.execute_query('BEGIN TRANSACTION')
else:
self.conn.autocommit = False
def commit_transaction(self) -> None:
"""提交事务"""
with SQLUtils._lock:
self.conn.commit()
if self.db_type not in ['sqlite', 'sqlite3']:
self.conn.autocommit = True
def rollback_transaction(self) -> None:
"""回滚事务"""
with SQLUtils._lock:
self.conn.rollback()
if self.db_type not in ['sqlite', 'sqlite3']:
self.conn.autocommit = True
def fetchone(self):
with SQLUtils._lock:
return self.cursor.fetchone()
def fetchall(self):
with SQLUtils._lock:
return self.cursor.fetchall()
def close(self):
"""关闭当前游标,但保留连接在连接池中"""
# 不再关闭连接,只关闭游标,减少重复创建连接的开销
# 连接仍然保留在连接池中供后续使用
pass
def real_close(self):
"""真正关闭连接,从连接池中移除"""
conn_key = f"{self.db_type}:{str(self.kwargs)}"
with SQLUtils._lock:
if conn_key in SQLUtils._connection_pool:
try:
conn, cursor = SQLUtils._connection_pool[conn_key]
if cursor:
cursor.close()
if conn:
conn.close()
del SQLUtils._connection_pool[conn_key]
logging.debug(f"已关闭并移除连接: {conn_key}")
except Exception as e:
logging.error(f"关闭连接失败: {e}")
@staticmethod
def close_all_connections():
"""关闭所有连接池中的连接"""
with SQLUtils._lock:
for conn, cursor in SQLUtils._connection_pool.values():
try:
if cursor:
cursor.close()
if conn:
conn.close()
except Exception as e:
logging.error(f"关闭数据库连接失败: {e}")
SQLUtils._connection_pool.clear()
logging.info("已关闭所有数据库连接")
@staticmethod
def get_sqlite_connection():
"""获取SQLite连接"""
return SQLUtils(source_name='sqlite')
@staticmethod
def get_postgresql_connection():
"""获取PostgreSQL连接"""
return SQLUtils(source_name='postgresql')
@staticmethod
def get_mysql_connection():
"""获取MySQL连接"""
return SQLUtils(source_name='mysql')