jiateng_ws/utils/sql_utils.py
2025-06-13 17:14:03 +08:00

209 lines
7.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import sys
import logging
from utils.config_loader import ConfigLoader
try:
import psycopg2
except ImportError:
psycopg2 = None
try:
import sqlite3
except ImportError:
sqlite3 = None
try:
import mysql.connector
except ImportError:
mysql = None
class SQLUtils:
# 存储连接池,避免重复创建连接
_connection_pool = {}
def __init__(self, db_type=None, source_name=None, **kwargs):
"""初始化SQLUtils对象
Args:
db_type: 数据库类型 'sqlite', 'postgresql', 'mysql' 之一如果为None则使用配置文件中的默认数据源
source_name: 数据源名称,用于从配置中获取特定的数据源,如'sqlite', 'postgresql', 'mysql'
**kwargs: 连接参数,如果没有提供,则使用配置文件中的参数
"""
self.conn = None
self.cursor = None
# 如果指定了source_name直接使用该名称的数据源配置
if source_name:
config_loader = ConfigLoader.get_instance()
db_config = config_loader.get_value(f'database.sources.{source_name}', {})
if not db_config:
raise ValueError(f"未找到数据源配置: {source_name}")
db_type = source_name # 数据源名称同时也是类型
if not kwargs: # 如果没有提供连接参数,则使用配置中的参数
if source_name == 'sqlite':
kwargs = {'database': db_config.get('path', 'db/jtDB.db')}
else:
kwargs = {
'host': db_config.get('host', 'localhost'),
'user': db_config.get('user', ''),
'password': db_config.get('password', ''),
'database': db_config.get('name', 'jtDB')
}
if 'port' in db_config and db_config['port']:
kwargs['port'] = int(db_config['port'])
# 如果没有指定数据库类型和数据源名称,则使用配置中的默认数据源
elif db_type is None:
config_loader = ConfigLoader.get_instance()
default_source = config_loader.get_value('database.default', 'sqlite')
# 如果没有提供连接参数,则从配置文件获取
if not kwargs:
db_config = config_loader.get_database_config(default_source)
if default_source == 'sqlite':
kwargs = {'database': db_config.get('path', 'db/jtDB.db')}
else:
kwargs = {
'host': db_config.get('host', 'localhost'),
'user': db_config.get('user', ''),
'password': db_config.get('password', ''),
'database': db_config.get('name', 'jtDB')
}
if 'port' in db_config and db_config['port']:
kwargs['port'] = int(db_config['port'])
db_type = default_source
self.db_type = db_type.lower()
self.kwargs = kwargs
self.source_name = source_name or self.db_type
# 尝试从连接池获取连接,如果没有则创建新连接
self._get_connection()
def _get_connection(self):
"""从连接池获取连接,如果没有则创建新连接"""
# 创建连接键,包含数据库类型和连接参数
conn_key = f"{self.db_type}:{str(self.kwargs)}"
# 检查连接池中是否已有此连接
if conn_key in SQLUtils._connection_pool:
try:
# 尝试执行简单查询,确认连接有效
conn, cursor = SQLUtils._connection_pool[conn_key]
cursor.execute("SELECT 1")
# 连接有效,直接使用
self.conn = conn
self.cursor = cursor
return
except Exception:
# 连接已失效,从连接池移除
del SQLUtils._connection_pool[conn_key]
# 创建新连接
self.connect()
# 将新连接添加到连接池
if self.conn and self.cursor:
SQLUtils._connection_pool[conn_key] = (self.conn, self.cursor)
def connect(self):
"""连接到数据库"""
try:
if self.db_type in ['pgsql', 'postgresql']:
if not psycopg2:
raise ImportError('psycopg2 is not installed')
self.conn = psycopg2.connect(**self.kwargs)
elif self.db_type in ['sqlite', 'sqlite3']:
if not sqlite3:
raise ImportError('sqlite3 is not installed')
self.conn = sqlite3.connect(self.kwargs.get('database', ':memory:'))
elif self.db_type == 'mysql':
if not mysql:
raise ImportError('mysql.connector is not installed')
self.conn = mysql.connector.connect(**self.kwargs)
else:
raise ValueError(f'不支持的数据库类型: {self.db_type}')
self.cursor = self.conn.cursor()
logging.debug(f"成功连接到数据库: {self.db_type}")
except Exception as e:
logging.error(f"连接数据库失败: {e}")
raise
def execute_query(self, sql, params=None):
if params is None:
params = ()
self.cursor.execute(sql, params)
self.conn.commit()
def execute_update(self, sql, params=None):
try:
self.cursor.execute(sql,params)
self.conn.commit()
except Exception as e:
self.conn.rollback()
raise e
def begin_transaction(self) -> None:
"""开始事务"""
if self.db_type in ['sqlite', 'sqlite3']:
self.execute_query('BEGIN TRANSACTION')
else:
self.conn.autocommit = False
def commit_transaction(self) -> None:
"""提交事务"""
self.conn.commit()
if self.db_type not in ['sqlite', 'sqlite3']:
self.conn.autocommit = True
def rollback_transaction(self) -> None:
"""回滚事务"""
self.conn.rollback()
if self.db_type not in ['sqlite', 'sqlite3']:
self.conn.autocommit = True
def fetchone(self):
return self.cursor.fetchone()
def fetchall(self):
return self.cursor.fetchall()
def close(self):
"""关闭连接(实际上是将连接返回到连接池)"""
# 这里不再实际关闭连接,让连接池管理连接生命周期
pass
@staticmethod
def close_all_connections():
"""关闭所有连接池中的连接"""
for conn, cursor in SQLUtils._connection_pool.values():
try:
if cursor:
cursor.close()
if conn:
conn.close()
except Exception as e:
logging.error(f"关闭数据库连接失败: {e}")
SQLUtils._connection_pool.clear()
logging.info("已关闭所有数据库连接")
@staticmethod
def get_sqlite_connection():
"""获取SQLite连接"""
return SQLUtils(source_name='sqlite')
@staticmethod
def get_postgresql_connection():
"""获取PostgreSQL连接"""
return SQLUtils(source_name='postgresql')
@staticmethod
def get_mysql_connection():
"""获取MySQL连接"""
return SQLUtils(source_name='mysql')