import sys import logging from utils.config_loader import ConfigLoader try: import psycopg2 except ImportError: psycopg2 = None try: import sqlite3 except ImportError: sqlite3 = None try: import mysql.connector except ImportError: mysql = None class SQLUtils: # 存储连接池,避免重复创建连接 _connection_pool = {} def __init__(self, db_type=None, source_name=None, **kwargs): """初始化SQLUtils对象 Args: db_type: 数据库类型 'sqlite', 'postgresql', 'mysql' 之一,如果为None则使用配置文件中的默认数据源 source_name: 数据源名称,用于从配置中获取特定的数据源,如'sqlite', 'postgresql', 'mysql' **kwargs: 连接参数,如果没有提供,则使用配置文件中的参数 """ self.conn = None self.cursor = None # 如果指定了source_name,直接使用该名称的数据源配置 if source_name: config_loader = ConfigLoader.get_instance() db_config = config_loader.get_value(f'database.sources.{source_name}', {}) if not db_config: raise ValueError(f"未找到数据源配置: {source_name}") db_type = source_name # 数据源名称同时也是类型 if not kwargs: # 如果没有提供连接参数,则使用配置中的参数 if source_name == 'sqlite': kwargs = {'database': db_config.get('path', 'db/jtDB.db')} else: kwargs = { 'host': db_config.get('host', 'localhost'), 'user': db_config.get('user', ''), 'password': db_config.get('password', ''), 'database': db_config.get('name', 'jtDB') } if 'port' in db_config and db_config['port']: kwargs['port'] = int(db_config['port']) # 如果没有指定数据库类型和数据源名称,则使用配置中的默认数据源 elif db_type is None: config_loader = ConfigLoader.get_instance() default_source = config_loader.get_value('database.default', 'sqlite') # 如果没有提供连接参数,则从配置文件获取 if not kwargs: db_config = config_loader.get_database_config(default_source) if default_source == 'sqlite': kwargs = {'database': db_config.get('path', 'db/jtDB.db')} else: kwargs = { 'host': db_config.get('host', 'localhost'), 'user': db_config.get('user', ''), 'password': db_config.get('password', ''), 'database': db_config.get('name', 'jtDB') } if 'port' in db_config and db_config['port']: kwargs['port'] = int(db_config['port']) db_type = default_source self.db_type = db_type.lower() self.kwargs = kwargs self.source_name = source_name or self.db_type # 尝试从连接池获取连接,如果没有则创建新连接 self._get_connection() def _get_connection(self): """从连接池获取连接,如果没有则创建新连接""" # 创建连接键,包含数据库类型和连接参数 conn_key = f"{self.db_type}:{str(self.kwargs)}" # 检查连接池中是否已有此连接 if conn_key in SQLUtils._connection_pool: try: # 尝试执行简单查询,确认连接有效 conn, cursor = SQLUtils._connection_pool[conn_key] cursor.execute("SELECT 1") # 连接有效,直接使用 self.conn = conn self.cursor = cursor return except Exception: # 连接已失效,从连接池移除 del SQLUtils._connection_pool[conn_key] # 创建新连接 self.connect() # 将新连接添加到连接池 if self.conn and self.cursor: SQLUtils._connection_pool[conn_key] = (self.conn, self.cursor) def connect(self): """连接到数据库""" try: if self.db_type in ['pgsql', 'postgresql']: if not psycopg2: raise ImportError('psycopg2 is not installed') self.conn = psycopg2.connect(**self.kwargs) elif self.db_type in ['sqlite', 'sqlite3']: if not sqlite3: raise ImportError('sqlite3 is not installed') self.conn = sqlite3.connect(self.kwargs.get('database', ':memory:')) elif self.db_type == 'mysql': if not mysql: raise ImportError('mysql.connector is not installed') self.conn = mysql.connector.connect(**self.kwargs) else: raise ValueError(f'不支持的数据库类型: {self.db_type}') self.cursor = self.conn.cursor() logging.debug(f"成功连接到数据库: {self.db_type}") except Exception as e: logging.error(f"连接数据库失败: {e}") raise def execute_query(self, sql, params=None): if params is None: params = () self.cursor.execute(sql, params) self.conn.commit() def execute_update(self, sql, params=None): try: self.cursor.execute(sql,params) self.conn.commit() except Exception as e: self.conn.rollback() raise e def begin_transaction(self) -> None: """开始事务""" if self.db_type in ['sqlite', 'sqlite3']: self.execute_query('BEGIN TRANSACTION') else: self.conn.autocommit = False def commit_transaction(self) -> None: """提交事务""" self.conn.commit() if self.db_type not in ['sqlite', 'sqlite3']: self.conn.autocommit = True def rollback_transaction(self) -> None: """回滚事务""" self.conn.rollback() if self.db_type not in ['sqlite', 'sqlite3']: self.conn.autocommit = True def fetchone(self): return self.cursor.fetchone() def fetchall(self): return self.cursor.fetchall() def close(self): """关闭连接(实际上是将连接返回到连接池)""" # 这里不再实际关闭连接,让连接池管理连接生命周期 pass @staticmethod def close_all_connections(): """关闭所有连接池中的连接""" for conn, cursor in SQLUtils._connection_pool.values(): try: if cursor: cursor.close() if conn: conn.close() except Exception as e: logging.error(f"关闭数据库连接失败: {e}") SQLUtils._connection_pool.clear() logging.info("已关闭所有数据库连接") @staticmethod def get_sqlite_connection(): """获取SQLite连接""" return SQLUtils(source_name='sqlite') @staticmethod def get_postgresql_connection(): """获取PostgreSQL连接""" return SQLUtils(source_name='postgresql') @staticmethod def get_mysql_connection(): """获取MySQL连接""" return SQLUtils(source_name='mysql')