diff --git a/vue-fastapi-backend/module_admin/service/metasecurity_service.py b/vue-fastapi-backend/module_admin/service/metasecurity_service.py index 8231941..518a3fd 100644 --- a/vue-fastapi-backend/module_admin/service/metasecurity_service.py +++ b/vue-fastapi-backend/module_admin/service/metasecurity_service.py @@ -9,12 +9,13 @@ from module_admin.dao.login_dao import login_by_account from module_admin.dao.user_dao import UserDao from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine from sqlalchemy.orm import sessionmaker -from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.exc import SQLAlchemyError, DBAPIError from sqlalchemy import text from config.env import AppConfig import requests from sqlalchemy.exc import OperationalError import json +import asyncio import re from decimal import Decimal import sqlparse @@ -317,7 +318,7 @@ class MetaSecurityService: # dataParams ={"user":"testuser","password":"testpd","address":"jdbc:postgresql://47.121.207.11:5432","database":"zx2","jdbcUrl":"jdbc:postgresql://47.121.207.11:5432/zx2","driverClassName":"org.postgresql.Driver","validationQuery":"select version()"} dsDataResource=await get_data_source_tree(request,page_object) # dbConnent= cls.get_db_engine("postgresql",dataParams]) - dbConnent= cls.get_db_engine(dsDataResource["type"],dsDataResource["connectionParams"]) + dbConnent= await cls.get_db_engine(dsDataResource["type"],dsDataResource["connectionParams"]) # await test_connection(dbConnent) #3获取sql中涉及的表名 sqlScheamAndTable =await get_tables_from_sql(page_object.sqlStr) @@ -364,64 +365,96 @@ class MetaSecurityService: - def get_db_engine(db_type: str, db_params: dict): - try: + async def get_db_engine(db_type: str, db_params: dict): + try: address = db_params['address'] + + # 1️⃣ 去掉 jdbc 前缀 jdbc_prefixes = { "jdbc:mysql://": len("jdbc:mysql://"), "jdbc:postgresql://": len("jdbc:postgresql://") } - - # Check and remove the matching prefix for prefix, length in jdbc_prefixes.items(): if address.startswith(prefix): address = address[length:] - db_params['address']=address - break # Once the correct prefix is found, exit the loop + break + + db_params['address'] = address + + # 2️⃣ 构建连接字符串 + 超时 if db_type.lower() == "mysql": - conn_str=f"mysql+aiomysql://{db_params['user']}:{db_params['password']}@{db_params['address']}/{db_params['database']}" - print(f"数据库连接字符串: {conn_str}") # 输出调试信息 - dbContent= create_async_engine(conn_str) - return dbContent + conn_str = ( + f"mysql+aiomysql://{db_params['user']}:{db_params['password']}" + f"@{db_params['address']}/{db_params['database']}" + ) + engine = create_async_engine( + conn_str, + pool_pre_ping=True, + connect_args={"connect_timeout": 5} # ⭐ 关键 + ) + elif db_type.lower() == "postgresql": - dbContent= create_async_engine(f"postgresql+asyncpg://{db_params['user']}:{db_params['password']}@{db_params['address']}/{db_params['database']}") - return dbContent + conn_str = ( + f"postgresql+asyncpg://{db_params['user']}:{db_params['password']}" + f"@{db_params['address']}/{db_params['database']}" + ) + engine = create_async_engine( + conn_str, + pool_pre_ping=True, + connect_args={"timeout": 5} # ⭐ 关键 + ) + else: - raise ValueError("不支持的数据库类型") + raise ValueError("不支持的数据库类型") + + # 3️⃣ ⭐ 主动发起一次连接校验(不然一定会卡) + async with engine.connect() as conn: + await conn.execute(text("SELECT 1")) + + return engine + + except asyncio.TimeoutError: + raise ConnectionError("数据库连接超时,请检查地址、端口或防火墙") + except SQLAlchemyError as e: - # 捕获SQLAlchemy相关的数据库连接错误 raise ConnectionError(f"数据库连接失败: {e}") + except Exception as e: - # 捕获其他非预期的错误 - raise RuntimeError(f"连接过程中发生了未知错误: {e}") + raise RuntimeError(f"连接过程中发生未知错误: {e}") @classmethod - async def execute_sql(cls, dbConnent, sql_query: str,sql_type: str): - # 创建异步会话 - async with dbConnent.begin(): - # 获取会话对象 - async_session = sessionmaker( - dbConnent, class_=AsyncSession, expire_on_commit=False - ) + async def execute_sql(cls, engine, sql_query: str, sql_type: str): + async_session = sessionmaker( + engine, + class_=AsyncSession, + expire_on_commit=False + ) + + try: async with async_session() as session: - try: - # 执行原始SQL查询 - query = text(sql_query) - result = await session.execute(query) - - # 获取所有结果 - rows = result.fetchall() + await session.execute(text("SET statement_timeout = 30000")) + + result = await session.execute(text(sql_query)) - # 获取列名 + if result.returns_rows: + rows = result.fetchall() columns = result.keys() + return [dict(zip(columns, row)) for row in rows] + + await session.commit() + return [] + + except DBAPIError as e: + # ⭐ 核心:统一兜 PostgreSQL 超时 + sqlstate = getattr(getattr(e, "orig", None), "sqlstate", None) + if sqlstate == "57014": + raise TimeoutError( + "SQL 执行超过 30 秒已被数据库中断,请先查询数据量或使用分页。" + ) + raise RuntimeError(f"{sql_type} 执行 SQL 失败: {e}") - # 将每一行转化为字典,键为列名 - result_dict = [dict(zip(columns, row)) for row in rows] - # # 使用 convert_decimal 处理数据 - # result_dict = [convert_decimal(row) for row in result_dict] - # 转换为 JSON 字符串 - return result_dict - except SQLAlchemyError as e: - raise RuntimeError(f"{sql_type}执行 SQL 查询时发生错误: {e}") + except SQLAlchemyError as e: + raise RuntimeError(f"{sql_type} 执行 SQL 失败: {e}") + @classmethod