|
|
@ -9,12 +9,13 @@ from module_admin.dao.login_dao import login_by_account |
|
|
from module_admin.dao.user_dao import UserDao |
|
|
from module_admin.dao.user_dao import UserDao |
|
|
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine |
|
|
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine |
|
|
from sqlalchemy.orm import sessionmaker |
|
|
from sqlalchemy.orm import sessionmaker |
|
|
from sqlalchemy.exc import SQLAlchemyError |
|
|
from sqlalchemy.exc import SQLAlchemyError, DBAPIError |
|
|
from sqlalchemy import text |
|
|
from sqlalchemy import text |
|
|
from config.env import AppConfig |
|
|
from config.env import AppConfig |
|
|
import requests |
|
|
import requests |
|
|
from sqlalchemy.exc import OperationalError |
|
|
from sqlalchemy.exc import OperationalError |
|
|
import json |
|
|
import json |
|
|
|
|
|
import asyncio |
|
|
import re |
|
|
import re |
|
|
from decimal import Decimal |
|
|
from decimal import Decimal |
|
|
import sqlparse |
|
|
import sqlparse |
|
|
@ -317,7 +318,7 @@ class MetaSecurityService: |
|
|
# dataParams ={"user":"testuser","password":"testpd","address":"jdbc:postgresql://47.121.207.11:5432","database":"zx2","jdbcUrl":"jdbc:postgresql://47.121.207.11:5432/zx2","driverClassName":"org.postgresql.Driver","validationQuery":"select version()"} |
|
|
# dataParams ={"user":"testuser","password":"testpd","address":"jdbc:postgresql://47.121.207.11:5432","database":"zx2","jdbcUrl":"jdbc:postgresql://47.121.207.11:5432/zx2","driverClassName":"org.postgresql.Driver","validationQuery":"select version()"} |
|
|
dsDataResource=await get_data_source_tree(request,page_object) |
|
|
dsDataResource=await get_data_source_tree(request,page_object) |
|
|
# dbConnent= cls.get_db_engine("postgresql",dataParams]) |
|
|
# dbConnent= cls.get_db_engine("postgresql",dataParams]) |
|
|
dbConnent= cls.get_db_engine(dsDataResource["type"],dsDataResource["connectionParams"]) |
|
|
dbConnent= await cls.get_db_engine(dsDataResource["type"],dsDataResource["connectionParams"]) |
|
|
# await test_connection(dbConnent) |
|
|
# await test_connection(dbConnent) |
|
|
#3获取sql中涉及的表名 |
|
|
#3获取sql中涉及的表名 |
|
|
sqlScheamAndTable =await get_tables_from_sql(page_object.sqlStr) |
|
|
sqlScheamAndTable =await get_tables_from_sql(page_object.sqlStr) |
|
|
@ -364,64 +365,96 @@ class MetaSecurityService: |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_db_engine(db_type: str, db_params: dict): |
|
|
async def get_db_engine(db_type: str, db_params: dict): |
|
|
try: |
|
|
try: |
|
|
address = db_params['address'] |
|
|
address = db_params['address'] |
|
|
|
|
|
|
|
|
|
|
|
# 1️⃣ 去掉 jdbc 前缀 |
|
|
jdbc_prefixes = { |
|
|
jdbc_prefixes = { |
|
|
"jdbc:mysql://": len("jdbc:mysql://"), |
|
|
"jdbc:mysql://": len("jdbc:mysql://"), |
|
|
"jdbc:postgresql://": len("jdbc:postgresql://") |
|
|
"jdbc:postgresql://": len("jdbc:postgresql://") |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
# Check and remove the matching prefix |
|
|
|
|
|
for prefix, length in jdbc_prefixes.items(): |
|
|
for prefix, length in jdbc_prefixes.items(): |
|
|
if address.startswith(prefix): |
|
|
if address.startswith(prefix): |
|
|
address = address[length:] |
|
|
address = address[length:] |
|
|
db_params['address']=address |
|
|
break |
|
|
break # Once the correct prefix is found, exit the loop |
|
|
|
|
|
|
|
|
db_params['address'] = address |
|
|
|
|
|
|
|
|
|
|
|
# 2️⃣ 构建连接字符串 + 超时 |
|
|
if db_type.lower() == "mysql": |
|
|
if db_type.lower() == "mysql": |
|
|
conn_str=f"mysql+aiomysql://{db_params['user']}:{db_params['password']}@{db_params['address']}/{db_params['database']}" |
|
|
conn_str = ( |
|
|
print(f"数据库连接字符串: {conn_str}") # 输出调试信息 |
|
|
f"mysql+aiomysql://{db_params['user']}:{db_params['password']}" |
|
|
dbContent= create_async_engine(conn_str) |
|
|
f"@{db_params['address']}/{db_params['database']}" |
|
|
return dbContent |
|
|
) |
|
|
|
|
|
engine = create_async_engine( |
|
|
|
|
|
conn_str, |
|
|
|
|
|
pool_pre_ping=True, |
|
|
|
|
|
connect_args={"connect_timeout": 5} # ⭐ 关键 |
|
|
|
|
|
) |
|
|
|
|
|
|
|
|
elif db_type.lower() == "postgresql": |
|
|
elif db_type.lower() == "postgresql": |
|
|
dbContent= create_async_engine(f"postgresql+asyncpg://{db_params['user']}:{db_params['password']}@{db_params['address']}/{db_params['database']}") |
|
|
conn_str = ( |
|
|
return dbContent |
|
|
f"postgresql+asyncpg://{db_params['user']}:{db_params['password']}" |
|
|
|
|
|
f"@{db_params['address']}/{db_params['database']}" |
|
|
|
|
|
) |
|
|
|
|
|
engine = create_async_engine( |
|
|
|
|
|
conn_str, |
|
|
|
|
|
pool_pre_ping=True, |
|
|
|
|
|
connect_args={"timeout": 5} # ⭐ 关键 |
|
|
|
|
|
) |
|
|
|
|
|
|
|
|
else: |
|
|
else: |
|
|
raise ValueError("不支持的数据库类型") |
|
|
raise ValueError("不支持的数据库类型") |
|
|
|
|
|
|
|
|
|
|
|
# 3️⃣ ⭐ 主动发起一次连接校验(不然一定会卡) |
|
|
|
|
|
async with engine.connect() as conn: |
|
|
|
|
|
await conn.execute(text("SELECT 1")) |
|
|
|
|
|
|
|
|
|
|
|
return engine |
|
|
|
|
|
|
|
|
|
|
|
except asyncio.TimeoutError: |
|
|
|
|
|
raise ConnectionError("数据库连接超时,请检查地址、端口或防火墙") |
|
|
|
|
|
|
|
|
except SQLAlchemyError as e: |
|
|
except SQLAlchemyError as e: |
|
|
# 捕获SQLAlchemy相关的数据库连接错误 |
|
|
|
|
|
raise ConnectionError(f"数据库连接失败: {e}") |
|
|
raise ConnectionError(f"数据库连接失败: {e}") |
|
|
|
|
|
|
|
|
except Exception as e: |
|
|
except Exception as e: |
|
|
# 捕获其他非预期的错误 |
|
|
raise RuntimeError(f"连接过程中发生未知错误: {e}") |
|
|
raise RuntimeError(f"连接过程中发生了未知错误: {e}") |
|
|
|
|
|
@classmethod |
|
|
@classmethod |
|
|
async def execute_sql(cls, dbConnent, sql_query: str,sql_type: str): |
|
|
async def execute_sql(cls, engine, sql_query: str, sql_type: str): |
|
|
# 创建异步会话 |
|
|
async_session = sessionmaker( |
|
|
async with dbConnent.begin(): |
|
|
engine, |
|
|
# 获取会话对象 |
|
|
class_=AsyncSession, |
|
|
async_session = sessionmaker( |
|
|
expire_on_commit=False |
|
|
dbConnent, class_=AsyncSession, expire_on_commit=False |
|
|
) |
|
|
) |
|
|
|
|
|
|
|
|
try: |
|
|
async with async_session() as session: |
|
|
async with async_session() as session: |
|
|
try: |
|
|
await session.execute(text("SET statement_timeout = 30000")) |
|
|
# 执行原始SQL查询 |
|
|
|
|
|
query = text(sql_query) |
|
|
result = await session.execute(text(sql_query)) |
|
|
result = await session.execute(query) |
|
|
|
|
|
|
|
|
|
|
|
# 获取所有结果 |
|
|
|
|
|
rows = result.fetchall() |
|
|
|
|
|
|
|
|
|
|
|
# 获取列名 |
|
|
if result.returns_rows: |
|
|
|
|
|
rows = result.fetchall() |
|
|
columns = result.keys() |
|
|
columns = result.keys() |
|
|
|
|
|
return [dict(zip(columns, row)) for row in rows] |
|
|
|
|
|
|
|
|
|
|
|
await session.commit() |
|
|
|
|
|
return [] |
|
|
|
|
|
|
|
|
|
|
|
except DBAPIError as e: |
|
|
|
|
|
# ⭐ 核心:统一兜 PostgreSQL 超时 |
|
|
|
|
|
sqlstate = getattr(getattr(e, "orig", None), "sqlstate", None) |
|
|
|
|
|
if sqlstate == "57014": |
|
|
|
|
|
raise TimeoutError( |
|
|
|
|
|
"SQL 执行超过 30 秒已被数据库中断,请先查询数据量或使用分页。" |
|
|
|
|
|
) |
|
|
|
|
|
raise RuntimeError(f"{sql_type} 执行 SQL 失败: {e}") |
|
|
|
|
|
|
|
|
# 将每一行转化为字典,键为列名 |
|
|
except SQLAlchemyError as e: |
|
|
result_dict = [dict(zip(columns, row)) for row in rows] |
|
|
raise RuntimeError(f"{sql_type} 执行 SQL 失败: {e}") |
|
|
# # 使用 convert_decimal 处理数据 |
|
|
|
|
|
# result_dict = [convert_decimal(row) for row in result_dict] |
|
|
|
|
|
# 转换为 JSON 字符串 |
|
|
|
|
|
return result_dict |
|
|
|
|
|
except SQLAlchemyError as e: |
|
|
|
|
|
raise RuntimeError(f"{sql_type}执行 SQL 查询时发生错误: {e}") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod |
|
|
@classmethod |
|
|
|