From fb4572262d6f00fe66b9314e53f8025f179e5c35 Mon Sep 17 00:00:00 2001 From: siyaqi Date: Sat, 17 Jan 2026 14:12:08 +0800 Subject: [PATCH] =?UTF-8?q?=E6=95=B0=E6=8D=AE=E5=AE=89=E5=85=A8=E6=8E=A5?= =?UTF-8?q?=E5=8F=A3=E5=8D=87=E7=BA=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../service/metasecurity_service.py | 64 ++++++++++++++++--- 1 file changed, 56 insertions(+), 8 deletions(-) diff --git a/vue-fastapi-backend/module_admin/service/metasecurity_service.py b/vue-fastapi-backend/module_admin/service/metasecurity_service.py index 62c31d4..a62b2a4 100644 --- a/vue-fastapi-backend/module_admin/service/metasecurity_service.py +++ b/vue-fastapi-backend/module_admin/service/metasecurity_service.py @@ -20,6 +20,7 @@ import re from decimal import Decimal import sqlglot from sqlglot.expressions import Table +from sqlglot import exp ,parse_one from typing import Set from sqlparse.tokens import Keyword, DML class MetaSecurityService: @@ -324,13 +325,17 @@ class MetaSecurityService: #3获取sql中涉及的表名 sqlScheamAndTable =await get_tables_from_sql(page_object.sqlStr) oldStrSql= page_object.sqlStr + if has_pagination(oldStrSql,dsDataResource["type"]): + page_object.isPage=False if page_object.isPage: oldStrSql=generate_pagination_sql(page_object,dsDataResource["type"]) #4.执行原始sql - result = await cls.execute_sql(dbConnent, oldStrSql,"原始") + rawSqlRowCount = await cls.execute_sql(dbConnent, oldStrSql,"原始结果",dsDataResource["type"]) if 3 in role_id_list: + result = await cls.execute_sql(dbConnent, oldStrSql,"原始",dsDataResource["type"]) resultDict={ "ctrlSql": page_object.sqlStr, + "rawSqlRowCount": rawSqlRowCount, "data": result, "message":"数据安全管理员权限" } @@ -356,9 +361,10 @@ class MetaSecurityService: #7.根据行列配置控制原始sql newStrSql =await replace_table_with_subquery(ctrSqlDict,oldStrSql) #8.执行结果 - result = await cls.execute_sql(dbConnent, newStrSql,"控制后") + result = await cls.execute_sql(dbConnent, newStrSql,"控制后",dsDataResource["type"]) resultDict={ "ctrlSql": newStrSql, + "rawSqlRowCount": rawSqlRowCount, "data": result, "tablesRowCol":tablesRowCol } @@ -423,7 +429,7 @@ class MetaSecurityService: except Exception as e: raise RuntimeError(f"连接过程中发生未知错误: {e}") @classmethod - async def execute_sql(cls, engine, sql_query: str, sql_type: str): + async def execute_sql(cls, engine, sql_query: str, sql_type: str,db_type: str): async_session = sessionmaker( engine, class_=AsyncSession, @@ -433,7 +439,11 @@ class MetaSecurityService: try: async with async_session() as session: await session.execute(text("SET statement_timeout = 30000")) - + # ⭐ 原始数量 + if sql_type == "原始结果": + count_sql = cls.build_count_sql(sql_query,db_type) + result = await session.execute(text(count_sql)) + return result.scalar_one() result = await session.execute(text(sql_query)) if result.returns_rows: @@ -455,9 +465,19 @@ class MetaSecurityService: except SQLAlchemyError as e: raise RuntimeError(f"{sql_type} 执行 SQL 失败: {e}") - - - + @classmethod + def build_count_sql(cls, sql: str, db_type: str = "POSTGRESQL") -> str: + dialect = db_type_to_sqlglot_dialect(db_type) + parsed = parse_one(sql, read=dialect) + + select = parsed.find(exp.Select) + if select: + select.set("order", None) + select.set("limit", None) + + cleaned_sql = parsed.sql(dialect=dialect) + return f"SELECT COUNT(*) AS cnt FROM ({cleaned_sql}) t" + @classmethod async def get_columns_from_tables(cls, dbConnent, table_names, db_type: str): """查询每个表的字段信息,根据数据库类型调整查询""" @@ -483,7 +503,7 @@ class MetaSecurityService: raise ValueError(f"暂不支持数据库类型: {db_type}") # Execute the query for the specific table - result = await cls.execute_sql(dbConnent, query, "字段查询") + result = await cls.execute_sql(dbConnent, query, "字段查询",db_type) # 将结果转换为字典格式 {table_name: ['column1', 'column2', ...]} columns[table_name] = [row["column_name"] for row in result] @@ -499,6 +519,18 @@ def unquote_ident(name: str) -> str: if len(name) >= 2 and name[0] in ('"', "'", '`') and name[-1] == name[0]: return name[1:-1] return name +def db_type_to_sqlglot_dialect(db_type: str) -> str: + mapping = { + "MYSQL": "mysql", + "POSTGRESQL": "postgres", + "SQLSERVER": "tsql", + "ORACLE": "oracle", + } + dialect = mapping.get(db_type.upper()) + if not dialect: + raise ValueError(f"Unsupported db_type: {db_type}") + return dialect + def convert_decimal(obj): if isinstance(obj, Decimal): return float(obj) # 或者 str(obj) 来保留精度 @@ -872,6 +904,22 @@ async def test_connection(db_content): await connection.scalar("SELECT 1") except Exception as e: raise Exception("数据源连接失败") from e +def has_pagination(sql: str, db_type: str) -> bool: + sql_upper = sql.upper() + + if db_type in ("MYSQL", "POSTGRESQL"): + return bool(re.search(r"\bLIMIT\b", sql_upper)) + + if db_type == "SQLSERVER": + return ( + "OFFSET" in sql_upper and "FETCH" in sql_upper + ) or "ROW_NUMBER" in sql_upper + + if db_type == "ORACLE": + return "ROWNUM" in sql_upper + + return False + def generate_pagination_sql(page_object: MetaSecurityApiModel, db_type: str) -> str: """ 生成带分页的 SQL 语句