Browse Source

代码完善

master
siyaqi 3 weeks ago
parent
commit
11b16c3087
  1. 109
      vue-fastapi-backend/module_admin/service/metasecurity_service.py

109
vue-fastapi-backend/module_admin/service/metasecurity_service.py

@ -9,12 +9,13 @@ from module_admin.dao.login_dao import login_by_account
from module_admin.dao.user_dao import UserDao from module_admin.dao.user_dao import UserDao
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import SQLAlchemyError, DBAPIError
from sqlalchemy import text from sqlalchemy import text
from config.env import AppConfig from config.env import AppConfig
import requests import requests
from sqlalchemy.exc import OperationalError from sqlalchemy.exc import OperationalError
import json import json
import asyncio
import re import re
from decimal import Decimal from decimal import Decimal
import sqlparse import sqlparse
@ -317,7 +318,7 @@ class MetaSecurityService:
# dataParams ={"user":"testuser","password":"testpd","address":"jdbc:postgresql://47.121.207.11:5432","database":"zx2","jdbcUrl":"jdbc:postgresql://47.121.207.11:5432/zx2","driverClassName":"org.postgresql.Driver","validationQuery":"select version()"} # dataParams ={"user":"testuser","password":"testpd","address":"jdbc:postgresql://47.121.207.11:5432","database":"zx2","jdbcUrl":"jdbc:postgresql://47.121.207.11:5432/zx2","driverClassName":"org.postgresql.Driver","validationQuery":"select version()"}
dsDataResource=await get_data_source_tree(request,page_object) dsDataResource=await get_data_source_tree(request,page_object)
# dbConnent= cls.get_db_engine("postgresql",dataParams]) # dbConnent= cls.get_db_engine("postgresql",dataParams])
dbConnent= cls.get_db_engine(dsDataResource["type"],dsDataResource["connectionParams"]) dbConnent= await cls.get_db_engine(dsDataResource["type"],dsDataResource["connectionParams"])
# await test_connection(dbConnent) # await test_connection(dbConnent)
#3获取sql中涉及的表名 #3获取sql中涉及的表名
sqlScheamAndTable =await get_tables_from_sql(page_object.sqlStr) sqlScheamAndTable =await get_tables_from_sql(page_object.sqlStr)
@ -364,64 +365,96 @@ class MetaSecurityService:
def get_db_engine(db_type: str, db_params: dict): async def get_db_engine(db_type: str, db_params: dict):
try: try:
address = db_params['address'] address = db_params['address']
# 1️⃣ 去掉 jdbc 前缀
jdbc_prefixes = { jdbc_prefixes = {
"jdbc:mysql://": len("jdbc:mysql://"), "jdbc:mysql://": len("jdbc:mysql://"),
"jdbc:postgresql://": len("jdbc:postgresql://") "jdbc:postgresql://": len("jdbc:postgresql://")
} }
# Check and remove the matching prefix
for prefix, length in jdbc_prefixes.items(): for prefix, length in jdbc_prefixes.items():
if address.startswith(prefix): if address.startswith(prefix):
address = address[length:] address = address[length:]
db_params['address']=address break
break # Once the correct prefix is found, exit the loop
db_params['address'] = address
# 2️⃣ 构建连接字符串 + 超时
if db_type.lower() == "mysql": if db_type.lower() == "mysql":
conn_str=f"mysql+aiomysql://{db_params['user']}:{db_params['password']}@{db_params['address']}/{db_params['database']}" conn_str = (
print(f"数据库连接字符串: {conn_str}") # 输出调试信息 f"mysql+aiomysql://{db_params['user']}:{db_params['password']}"
dbContent= create_async_engine(conn_str) f"@{db_params['address']}/{db_params['database']}"
return dbContent )
engine = create_async_engine(
conn_str,
pool_pre_ping=True,
connect_args={"connect_timeout": 5} # ⭐ 关键
)
elif db_type.lower() == "postgresql": elif db_type.lower() == "postgresql":
dbContent= create_async_engine(f"postgresql+asyncpg://{db_params['user']}:{db_params['password']}@{db_params['address']}/{db_params['database']}") conn_str = (
return dbContent f"postgresql+asyncpg://{db_params['user']}:{db_params['password']}"
f"@{db_params['address']}/{db_params['database']}"
)
engine = create_async_engine(
conn_str,
pool_pre_ping=True,
connect_args={"timeout": 5} # ⭐ 关键
)
else: else:
raise ValueError("不支持的数据库类型") raise ValueError("不支持的数据库类型")
# 3️⃣ ⭐ 主动发起一次连接校验(不然一定会卡)
async with engine.connect() as conn:
await conn.execute(text("SELECT 1"))
return engine
except asyncio.TimeoutError:
raise ConnectionError("数据库连接超时,请检查地址、端口或防火墙")
except SQLAlchemyError as e: except SQLAlchemyError as e:
# 捕获SQLAlchemy相关的数据库连接错误
raise ConnectionError(f"数据库连接失败: {e}") raise ConnectionError(f"数据库连接失败: {e}")
except Exception as e: except Exception as e:
# 捕获其他非预期的错误 raise RuntimeError(f"连接过程中发生未知错误: {e}")
raise RuntimeError(f"连接过程中发生了未知错误: {e}")
@classmethod @classmethod
async def execute_sql(cls, dbConnent, sql_query: str,sql_type: str): async def execute_sql(cls, engine, sql_query: str, sql_type: str):
# 创建异步会话 async_session = sessionmaker(
async with dbConnent.begin(): engine,
# 获取会话对象 class_=AsyncSession,
async_session = sessionmaker( expire_on_commit=False
dbConnent, class_=AsyncSession, expire_on_commit=False )
)
try:
async with async_session() as session: async with async_session() as session:
try: await session.execute(text("SET statement_timeout = 30000"))
# 执行原始SQL查询
query = text(sql_query)
result = await session.execute(query)
# 获取所有结果 result = await session.execute(text(sql_query))
rows = result.fetchall()
# 获取列名 if result.returns_rows:
rows = result.fetchall()
columns = result.keys() columns = result.keys()
return [dict(zip(columns, row)) for row in rows]
await session.commit()
return []
except DBAPIError as e:
# ⭐ 核心:统一兜 PostgreSQL 超时
sqlstate = getattr(getattr(e, "orig", None), "sqlstate", None)
if sqlstate == "57014":
raise TimeoutError(
"SQL 执行超过 30 秒已被数据库中断,请先查询数据量或使用分页。"
)
raise RuntimeError(f"{sql_type} 执行 SQL 失败: {e}")
except SQLAlchemyError as e:
raise RuntimeError(f"{sql_type} 执行 SQL 失败: {e}")
# 将每一行转化为字典,键为列名
result_dict = [dict(zip(columns, row)) for row in rows]
# # 使用 convert_decimal 处理数据
# result_dict = [convert_decimal(row) for row in result_dict]
# 转换为 JSON 字符串
return result_dict
except SQLAlchemyError as e:
raise RuntimeError(f"{sql_type}执行 SQL 查询时发生错误: {e}")
@classmethod @classmethod

Loading…
Cancel
Save