import sys import logging import threading from utils.config_loader import ConfigLoader try: import psycopg2 except ImportError: psycopg2 = None try: import sqlite3 except ImportError: sqlite3 = None try: import mysql.connector except ImportError: mysql = None class SQLUtils: # 存储连接池,避免重复创建连接 _connection_pool = {} # 添加线程锁,用于防止多线程同时使用同一个cursor _lock = threading.RLock() def __init__(self, db_type=None, source_name=None, **kwargs): """初始化SQLUtils对象 Args: db_type: 数据库类型 'sqlite', 'postgresql', 'mysql' 之一,如果为None则使用配置文件中的默认数据源 source_name: 数据源名称,用于从配置中获取特定的数据源,如'sqlite', 'postgresql', 'mysql' **kwargs: 连接参数,如果没有提供,则使用配置文件中的参数 """ self.conn = None self.cursor = None # 如果指定了source_name,直接使用该名称的数据源配置 if source_name: config_loader = ConfigLoader.get_instance() db_config = config_loader.get_value(f'database.sources.{source_name}', {}) if not db_config: raise ValueError(f"未找到数据源配置: {source_name}") db_type = source_name # 数据源名称同时也是类型 if not kwargs: # 如果没有提供连接参数,则使用配置中的参数 if source_name == 'sqlite': kwargs = {'database': db_config.get('path', 'db/jtDB.db')} else: kwargs = { 'host': db_config.get('host', 'localhost'), 'user': db_config.get('user', ''), 'password': db_config.get('password', ''), 'database': db_config.get('name', 'jtDB') } if 'port' in db_config and db_config['port']: kwargs['port'] = int(db_config['port']) # 如果没有指定数据库类型和数据源名称,则使用配置中的默认数据源 elif db_type is None: config_loader = ConfigLoader.get_instance() default_source = config_loader.get_value('database.default', 'sqlite') # 如果没有提供连接参数,则从配置文件获取 if not kwargs: db_config = config_loader.get_database_config(default_source) if default_source == 'sqlite': kwargs = {'database': db_config.get('path', 'db/jtDB.db')} else: kwargs = { 'host': db_config.get('host', 'localhost'), 'user': db_config.get('user', ''), 'password': db_config.get('password', ''), 'database': db_config.get('name', 'jtDB') } if 'port' in db_config and db_config['port']: kwargs['port'] = int(db_config['port']) db_type = default_source self.db_type = db_type.lower() self.kwargs = kwargs self.source_name = source_name or self.db_type # 尝试从连接池获取连接,如果没有则创建新连接 self._get_connection() def __enter__(self): """上下文管理器入口方法,支持with语句""" return self def __exit__(self, exc_type, exc_val, exc_tb): """上下文管理器退出方法,自动关闭游标""" self.close() return False # 不抑制异常 def _get_connection(self): """从连接池获取连接,如果没有则创建新连接""" # 创建连接键,包含数据库类型和连接参数 conn_key = f"{self.db_type}:{str(self.kwargs)}" # 检查连接池中是否已有此连接 with SQLUtils._lock: if conn_key in SQLUtils._connection_pool: try: # 尝试执行简单查询,确认连接有效 conn, cursor = SQLUtils._connection_pool[conn_key] cursor.execute("SELECT 1") # 连接有效,直接使用 self.conn = conn self.cursor = cursor return except Exception: # 连接已失效,从连接池移除 del SQLUtils._connection_pool[conn_key] # 创建新连接 self.connect() # 将新连接添加到连接池 if self.conn and self.cursor: SQLUtils._connection_pool[conn_key] = (self.conn, self.cursor) def connect(self): """连接到数据库""" try: if self.db_type in ['pgsql', 'postgresql']: if not psycopg2: raise ImportError('psycopg2 is not installed') self.conn = psycopg2.connect(**self.kwargs) elif self.db_type in ['sqlite', 'sqlite3']: if not sqlite3: raise ImportError('sqlite3 is not installed') self.conn = sqlite3.connect(self.kwargs.get('database', ':memory:'), check_same_thread=False) elif self.db_type == 'mysql': if not mysql: raise ImportError('mysql.connector is not installed') self.conn = mysql.connector.connect(**self.kwargs) else: raise ValueError(f'不支持的数据库类型: {self.db_type}') self.cursor = self.conn.cursor() logging.debug(f"成功连接到数据库: {self.db_type}") except Exception as e: logging.error(f"连接数据库失败: {e}") raise def execute_query(self, sql, params=None): if params is None: params = () with SQLUtils._lock: try: self.cursor.execute(sql, params) return self.cursor except Exception as e: logging.error(f"执行查询失败: {e}, SQL: {sql}, 参数: {params}") raise def execute_update(self, sql, params=None): try: with SQLUtils._lock: self.cursor.execute(sql, params) self.conn.commit() return self.cursor.rowcount except Exception as e: self.conn.rollback() logging.error(f"执行更新失败: {e}, SQL: {sql}, 参数: {params}") raise e def begin_transaction(self) -> None: """开始事务""" with SQLUtils._lock: if self.db_type in ['sqlite', 'sqlite3']: self.execute_query('BEGIN TRANSACTION') else: self.conn.autocommit = False def commit_transaction(self) -> None: """提交事务""" with SQLUtils._lock: self.conn.commit() if self.db_type not in ['sqlite', 'sqlite3']: self.conn.autocommit = True def rollback_transaction(self) -> None: """回滚事务""" with SQLUtils._lock: self.conn.rollback() if self.db_type not in ['sqlite', 'sqlite3']: self.conn.autocommit = True def fetchone(self): with SQLUtils._lock: return self.cursor.fetchone() def fetchall(self): with SQLUtils._lock: return self.cursor.fetchall() def close(self): """关闭当前游标,但保留连接在连接池中""" # 不再关闭连接,只关闭游标,减少重复创建连接的开销 # 连接仍然保留在连接池中供后续使用 pass def real_close(self): """真正关闭连接,从连接池中移除""" conn_key = f"{self.db_type}:{str(self.kwargs)}" with SQLUtils._lock: if conn_key in SQLUtils._connection_pool: try: conn, cursor = SQLUtils._connection_pool[conn_key] if cursor: cursor.close() if conn: conn.close() del SQLUtils._connection_pool[conn_key] logging.debug(f"已关闭并移除连接: {conn_key}") except Exception as e: logging.error(f"关闭连接失败: {e}") @staticmethod def close_all_connections(): """关闭所有连接池中的连接""" with SQLUtils._lock: for conn, cursor in SQLUtils._connection_pool.values(): try: if cursor: cursor.close() if conn: conn.close() except Exception as e: logging.error(f"关闭数据库连接失败: {e}") SQLUtils._connection_pool.clear() logging.info("已关闭所有数据库连接") @staticmethod def get_sqlite_connection(): """获取SQLite连接""" return SQLUtils(source_name='sqlite') @staticmethod def get_postgresql_connection(): """获取PostgreSQL连接""" return SQLUtils(source_name='postgresql') @staticmethod def get_mysql_connection(): """获取MySQL连接""" return SQLUtils(source_name='mysql')