Browse Source

数据安全接口升级

master
siyaqi 2 weeks ago
parent
commit
fb4572262d
  1. 64
      vue-fastapi-backend/module_admin/service/metasecurity_service.py

64
vue-fastapi-backend/module_admin/service/metasecurity_service.py

@ -20,6 +20,7 @@ import re
from decimal import Decimal
import sqlglot
from sqlglot.expressions import Table
from sqlglot import exp ,parse_one
from typing import Set
from sqlparse.tokens import Keyword, DML
class MetaSecurityService:
@ -324,13 +325,17 @@ class MetaSecurityService:
#3获取sql中涉及的表名
sqlScheamAndTable =await get_tables_from_sql(page_object.sqlStr)
oldStrSql= page_object.sqlStr
if has_pagination(oldStrSql,dsDataResource["type"]):
page_object.isPage=False
if page_object.isPage:
oldStrSql=generate_pagination_sql(page_object,dsDataResource["type"])
#4.执行原始sql
result = await cls.execute_sql(dbConnent, oldStrSql,"原始")
rawSqlRowCount = await cls.execute_sql(dbConnent, oldStrSql,"原始结果",dsDataResource["type"])
if 3 in role_id_list:
result = await cls.execute_sql(dbConnent, oldStrSql,"原始",dsDataResource["type"])
resultDict={
"ctrlSql": page_object.sqlStr,
"rawSqlRowCount": rawSqlRowCount,
"data": result,
"message":"数据安全管理员权限"
}
@ -356,9 +361,10 @@ class MetaSecurityService:
#7.根据行列配置控制原始sql
newStrSql =await replace_table_with_subquery(ctrSqlDict,oldStrSql)
#8.执行结果
result = await cls.execute_sql(dbConnent, newStrSql,"控制后")
result = await cls.execute_sql(dbConnent, newStrSql,"控制后",dsDataResource["type"])
resultDict={
"ctrlSql": newStrSql,
"rawSqlRowCount": rawSqlRowCount,
"data": result,
"tablesRowCol":tablesRowCol
}
@ -423,7 +429,7 @@ class MetaSecurityService:
except Exception as e:
raise RuntimeError(f"连接过程中发生未知错误: {e}")
@classmethod
async def execute_sql(cls, engine, sql_query: str, sql_type: str):
async def execute_sql(cls, engine, sql_query: str, sql_type: str,db_type: str):
async_session = sessionmaker(
engine,
class_=AsyncSession,
@ -433,7 +439,11 @@ class MetaSecurityService:
try:
async with async_session() as session:
await session.execute(text("SET statement_timeout = 30000"))
# ⭐ 原始数量
if sql_type == "原始结果":
count_sql = cls.build_count_sql(sql_query,db_type)
result = await session.execute(text(count_sql))
return result.scalar_one()
result = await session.execute(text(sql_query))
if result.returns_rows:
@ -455,9 +465,19 @@ class MetaSecurityService:
except SQLAlchemyError as e:
raise RuntimeError(f"{sql_type} 执行 SQL 失败: {e}")
@classmethod
def build_count_sql(cls, sql: str, db_type: str = "POSTGRESQL") -> str:
dialect = db_type_to_sqlglot_dialect(db_type)
parsed = parse_one(sql, read=dialect)
select = parsed.find(exp.Select)
if select:
select.set("order", None)
select.set("limit", None)
cleaned_sql = parsed.sql(dialect=dialect)
return f"SELECT COUNT(*) AS cnt FROM ({cleaned_sql}) t"
@classmethod
async def get_columns_from_tables(cls, dbConnent, table_names, db_type: str):
"""查询每个表的字段信息,根据数据库类型调整查询"""
@ -483,7 +503,7 @@ class MetaSecurityService:
raise ValueError(f"暂不支持数据库类型: {db_type}")
# Execute the query for the specific table
result = await cls.execute_sql(dbConnent, query, "字段查询")
result = await cls.execute_sql(dbConnent, query, "字段查询",db_type)
# 将结果转换为字典格式 {table_name: ['column1', 'column2', ...]}
columns[table_name] = [row["column_name"] for row in result]
@ -499,6 +519,18 @@ def unquote_ident(name: str) -> str:
if len(name) >= 2 and name[0] in ('"', "'", '`') and name[-1] == name[0]:
return name[1:-1]
return name
def db_type_to_sqlglot_dialect(db_type: str) -> str:
mapping = {
"MYSQL": "mysql",
"POSTGRESQL": "postgres",
"SQLSERVER": "tsql",
"ORACLE": "oracle",
}
dialect = mapping.get(db_type.upper())
if not dialect:
raise ValueError(f"Unsupported db_type: {db_type}")
return dialect
def convert_decimal(obj):
if isinstance(obj, Decimal):
return float(obj) # 或者 str(obj) 来保留精度
@ -872,6 +904,22 @@ async def test_connection(db_content):
await connection.scalar("SELECT 1")
except Exception as e:
raise Exception("数据源连接失败") from e
def has_pagination(sql: str, db_type: str) -> bool:
sql_upper = sql.upper()
if db_type in ("MYSQL", "POSTGRESQL"):
return bool(re.search(r"\bLIMIT\b", sql_upper))
if db_type == "SQLSERVER":
return (
"OFFSET" in sql_upper and "FETCH" in sql_upper
) or "ROW_NUMBER" in sql_upper
if db_type == "ORACLE":
return "ROWNUM" in sql_upper
return False
def generate_pagination_sql(page_object: MetaSecurityApiModel, db_type: str) -> str:
"""
生成带分页的 SQL 语句

Loading…
Cancel
Save