|
|
|
@ -20,6 +20,7 @@ import re |
|
|
|
from decimal import Decimal |
|
|
|
import sqlglot |
|
|
|
from sqlglot.expressions import Table |
|
|
|
from sqlglot import exp ,parse_one |
|
|
|
from typing import Set |
|
|
|
from sqlparse.tokens import Keyword, DML |
|
|
|
class MetaSecurityService: |
|
|
|
@ -324,13 +325,17 @@ class MetaSecurityService: |
|
|
|
#3获取sql中涉及的表名 |
|
|
|
sqlScheamAndTable =await get_tables_from_sql(page_object.sqlStr) |
|
|
|
oldStrSql= page_object.sqlStr |
|
|
|
if has_pagination(oldStrSql,dsDataResource["type"]): |
|
|
|
page_object.isPage=False |
|
|
|
if page_object.isPage: |
|
|
|
oldStrSql=generate_pagination_sql(page_object,dsDataResource["type"]) |
|
|
|
#4.执行原始sql |
|
|
|
result = await cls.execute_sql(dbConnent, oldStrSql,"原始") |
|
|
|
rawSqlRowCount = await cls.execute_sql(dbConnent, oldStrSql,"原始结果",dsDataResource["type"]) |
|
|
|
if 3 in role_id_list: |
|
|
|
result = await cls.execute_sql(dbConnent, oldStrSql,"原始",dsDataResource["type"]) |
|
|
|
resultDict={ |
|
|
|
"ctrlSql": page_object.sqlStr, |
|
|
|
"rawSqlRowCount": rawSqlRowCount, |
|
|
|
"data": result, |
|
|
|
"message":"数据安全管理员权限" |
|
|
|
} |
|
|
|
@ -356,9 +361,10 @@ class MetaSecurityService: |
|
|
|
#7.根据行列配置控制原始sql |
|
|
|
newStrSql =await replace_table_with_subquery(ctrSqlDict,oldStrSql) |
|
|
|
#8.执行结果 |
|
|
|
result = await cls.execute_sql(dbConnent, newStrSql,"控制后") |
|
|
|
result = await cls.execute_sql(dbConnent, newStrSql,"控制后",dsDataResource["type"]) |
|
|
|
resultDict={ |
|
|
|
"ctrlSql": newStrSql, |
|
|
|
"rawSqlRowCount": rawSqlRowCount, |
|
|
|
"data": result, |
|
|
|
"tablesRowCol":tablesRowCol |
|
|
|
} |
|
|
|
@ -423,7 +429,7 @@ class MetaSecurityService: |
|
|
|
except Exception as e: |
|
|
|
raise RuntimeError(f"连接过程中发生未知错误: {e}") |
|
|
|
@classmethod |
|
|
|
async def execute_sql(cls, engine, sql_query: str, sql_type: str): |
|
|
|
async def execute_sql(cls, engine, sql_query: str, sql_type: str,db_type: str): |
|
|
|
async_session = sessionmaker( |
|
|
|
engine, |
|
|
|
class_=AsyncSession, |
|
|
|
@ -433,7 +439,11 @@ class MetaSecurityService: |
|
|
|
try: |
|
|
|
async with async_session() as session: |
|
|
|
await session.execute(text("SET statement_timeout = 30000")) |
|
|
|
|
|
|
|
# ⭐ 原始数量 |
|
|
|
if sql_type == "原始结果": |
|
|
|
count_sql = cls.build_count_sql(sql_query,db_type) |
|
|
|
result = await session.execute(text(count_sql)) |
|
|
|
return result.scalar_one() |
|
|
|
result = await session.execute(text(sql_query)) |
|
|
|
|
|
|
|
if result.returns_rows: |
|
|
|
@ -455,8 +465,18 @@ class MetaSecurityService: |
|
|
|
|
|
|
|
except SQLAlchemyError as e: |
|
|
|
raise RuntimeError(f"{sql_type} 执行 SQL 失败: {e}") |
|
|
|
@classmethod |
|
|
|
def build_count_sql(cls, sql: str, db_type: str = "POSTGRESQL") -> str: |
|
|
|
dialect = db_type_to_sqlglot_dialect(db_type) |
|
|
|
parsed = parse_one(sql, read=dialect) |
|
|
|
|
|
|
|
select = parsed.find(exp.Select) |
|
|
|
if select: |
|
|
|
select.set("order", None) |
|
|
|
select.set("limit", None) |
|
|
|
|
|
|
|
cleaned_sql = parsed.sql(dialect=dialect) |
|
|
|
return f"SELECT COUNT(*) AS cnt FROM ({cleaned_sql}) t" |
|
|
|
|
|
|
|
@classmethod |
|
|
|
async def get_columns_from_tables(cls, dbConnent, table_names, db_type: str): |
|
|
|
@ -483,7 +503,7 @@ class MetaSecurityService: |
|
|
|
raise ValueError(f"暂不支持数据库类型: {db_type}") |
|
|
|
|
|
|
|
# Execute the query for the specific table |
|
|
|
result = await cls.execute_sql(dbConnent, query, "字段查询") |
|
|
|
result = await cls.execute_sql(dbConnent, query, "字段查询",db_type) |
|
|
|
|
|
|
|
# 将结果转换为字典格式 {table_name: ['column1', 'column2', ...]} |
|
|
|
columns[table_name] = [row["column_name"] for row in result] |
|
|
|
@ -499,6 +519,18 @@ def unquote_ident(name: str) -> str: |
|
|
|
if len(name) >= 2 and name[0] in ('"', "'", '`') and name[-1] == name[0]: |
|
|
|
return name[1:-1] |
|
|
|
return name |
|
|
|
def db_type_to_sqlglot_dialect(db_type: str) -> str: |
|
|
|
mapping = { |
|
|
|
"MYSQL": "mysql", |
|
|
|
"POSTGRESQL": "postgres", |
|
|
|
"SQLSERVER": "tsql", |
|
|
|
"ORACLE": "oracle", |
|
|
|
} |
|
|
|
dialect = mapping.get(db_type.upper()) |
|
|
|
if not dialect: |
|
|
|
raise ValueError(f"Unsupported db_type: {db_type}") |
|
|
|
return dialect |
|
|
|
|
|
|
|
def convert_decimal(obj): |
|
|
|
if isinstance(obj, Decimal): |
|
|
|
return float(obj) # 或者 str(obj) 来保留精度 |
|
|
|
@ -872,6 +904,22 @@ async def test_connection(db_content): |
|
|
|
await connection.scalar("SELECT 1") |
|
|
|
except Exception as e: |
|
|
|
raise Exception("数据源连接失败") from e |
|
|
|
def has_pagination(sql: str, db_type: str) -> bool: |
|
|
|
sql_upper = sql.upper() |
|
|
|
|
|
|
|
if db_type in ("MYSQL", "POSTGRESQL"): |
|
|
|
return bool(re.search(r"\bLIMIT\b", sql_upper)) |
|
|
|
|
|
|
|
if db_type == "SQLSERVER": |
|
|
|
return ( |
|
|
|
"OFFSET" in sql_upper and "FETCH" in sql_upper |
|
|
|
) or "ROW_NUMBER" in sql_upper |
|
|
|
|
|
|
|
if db_type == "ORACLE": |
|
|
|
return "ROWNUM" in sql_upper |
|
|
|
|
|
|
|
return False |
|
|
|
|
|
|
|
def generate_pagination_sql(page_object: MetaSecurityApiModel, db_type: str) -> str: |
|
|
|
""" |
|
|
|
生成带分页的 SQL 语句 |
|
|
|
|