You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

994 lines
44 KiB

from fastapi import Request
from exceptions.exception import ServiceException
from module_admin.dao.metaSecurity_dao import MetaSecurityDao
from module_admin.entity.vo.common_vo import CrudResponseModel
from module_admin.entity.vo.metasecurity_vo import MetaSecurityColModel, MetaSecurityRowModel,DeleteMetaSecurityModel,MetaSecurityApiModel
from utils.common_util import CamelCaseUtil
import uuid
from module_admin.dao.login_dao import login_by_account
from module_admin.dao.user_dao import UserDao
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
3 weeks ago
from sqlalchemy.exc import SQLAlchemyError, DBAPIError
from sqlalchemy import text
from config.env import AppConfig
import requests
11 months ago
from sqlalchemy.exc import OperationalError
import json
3 weeks ago
import asyncio
import re
from decimal import Decimal
2 months ago
import sqlparse
from sqlparse.sql import Identifier, IdentifierList, Function, Token
from sqlparse.tokens import Keyword, DML
class MetaSecurityService:
    """
    Service layer for data-source security management.

    Provides CRUD for column/row security rules and executes user SQL against
    an external data source with those rules applied.
    """

    # Members of this role receive the raw query result without row/column
    # control (the response labels it "数据安全管理员权限").
    # NOTE(review): hard-coded role id; confirm it matches the role table.
    DATA_SECURITY_ADMIN_ROLE_ID = 3

    # Whole-word, case-insensitive deny-list applied to user SQL before execution.
    # NOTE(review): a keyword deny-list is not a sandbox (it does not cover e.g.
    # CREATE or GRANT); a read-only database account is the robust safeguard.
    _FORBIDDEN_SQL_PATTERN = re.compile(
        r"\b(UPDATE|DELETE|INSERT|DROP|ALTER|TRUNCATE)\b", re.IGNORECASE
    )

    @classmethod
    async def get_meta_security_col_list_services(
        cls, query_db: AsyncSession, query_object: MetaSecurityColModel, is_page: bool = False
    ):
        """
        Get the column security rule list.

        :param query_db: ORM session
        :param query_object: query parameters
        :param is_page: whether to paginate
        :return: column rule list as returned by the DAO
        """
        return await MetaSecurityDao.get_meta_security_col_list(query_db, query_object, is_page)

    @classmethod
    async def get_meta_security_row_list_services(
        cls, query_db: AsyncSession, query_object: MetaSecurityRowModel, is_page: bool = False
    ):
        """
        Get the row security rule list.

        :param query_db: ORM session
        :param query_object: query parameters
        :param is_page: whether to paginate
        :return: row rule list as returned by the DAO
        """
        return await MetaSecurityDao.get_meta_security_row_list(query_db, query_object, is_page)

    @classmethod
    async def get_meta_mdlName_list_services(cls, query_db: AsyncSession, ssys_id: int):
        """
        Get the schemas registered for a source system.

        :param query_db: ORM session
        :param ssys_id: source system id
        :return: result of ``MetaSecurityDao.get_schema_by_system``
        """
        return await MetaSecurityDao.get_schema_by_system(query_db, ssys_id)

    @classmethod
    async def get_meta_security_col_by_id_services(cls, query_db: AsyncSession, colId: str):
        """
        Get one column rule.

        :param query_db: ORM session
        :param colId: column rule id
        :return: ``MetaSecurityColModel``; an empty model when the id is unknown
        """
        col = await MetaSecurityDao.get_meta_security_col_by_id(query_db, colId)
        if col:
            return MetaSecurityColModel(**CamelCaseUtil.transform_result(col))
        return MetaSecurityColModel(**dict())

    @classmethod
    async def get_meta_security_row_by_id_services(cls, query_db: AsyncSession, rowId: str):
        """
        Get one row rule.

        :param query_db: ORM session
        :param rowId: row rule id
        :return: ``MetaSecurityRowModel``; an empty model when the id is unknown
        """
        row = await MetaSecurityDao.get_meta_security_row_by_id(query_db, rowId)
        if row:
            return MetaSecurityRowModel(**CamelCaseUtil.transform_result(row))
        return MetaSecurityRowModel(**dict())

    @classmethod
    async def add_meta_security_col_services(cls, request: Request, query_db: AsyncSession, page_object: MetaSecurityColModel):
        """
        Add column rules.

        ``obj_value`` and ``obj_name`` are parallel comma-separated lists. One rule
        with a fresh UUID is inserted per (value, name) pair, and pairs are
        matched by position. Extra items on the longer list are ignored, and
        nothing is inserted when ``obj_value`` is empty.

        :param request: request object (unused)
        :param query_db: ORM session
        :param page_object: submitted column rule
        :return: ``CrudResponseModel`` on success
        """
        try:
            if isinstance(page_object.obj_value, str) and page_object.obj_value:
                pairs = zip(page_object.obj_value.split(","), page_object.obj_name.split(","))
                for value, name in pairs:
                    # Copy the submitted object so the caller's instance is untouched.
                    new_page_object = MetaSecurityColModel(**page_object.model_dump(by_alias=True))
                    new_page_object.obj_value = value.strip()
                    new_page_object.obj_name = name.strip()
                    new_page_object.colId = str(uuid.uuid4())
                    await MetaSecurityDao.add_meta_security_col(query_db, new_page_object)
            await query_db.commit()
            return CrudResponseModel(is_success=True, message='新增列配置成功')
        except Exception:
            await query_db.rollback()
            raise

    @classmethod
    async def add_meta_security_row_services(cls, request: Request, query_db: AsyncSession, page_object: MetaSecurityRowModel):
        """
        Add row rules.

        ``obj_value`` and ``obj_name`` are parallel comma-separated lists. One rule
        with a fresh UUID is inserted per (value, name) pair, and pairs are
        matched by position. Extra items on the longer list are ignored, and
        nothing is inserted when ``obj_value`` is empty.

        :param request: request object (unused)
        :param query_db: ORM session
        :param page_object: submitted row rule
        :return: ``CrudResponseModel`` on success
        """
        try:
            if isinstance(page_object.obj_value, str) and page_object.obj_value:
                pairs = zip(page_object.obj_value.split(","), page_object.obj_name.split(","))
                for value, name in pairs:
                    # Copy the submitted object so the caller's instance is untouched.
                    new_page_object = MetaSecurityRowModel(**page_object.model_dump(by_alias=True))
                    new_page_object.obj_value = value.strip()
                    new_page_object.obj_name = name.strip()
                    new_page_object.rowId = str(uuid.uuid4())
                    await MetaSecurityDao.add_meta_security_row(query_db, new_page_object)
            await query_db.commit()
            return CrudResponseModel(is_success=True, message='新增行配置成功')
        except Exception:
            await query_db.rollback()
            raise

    @classmethod
    async def col_detail_services(cls, query_db: AsyncSession, col: str):
        """
        Get column rule details; same result as ``get_meta_security_col_by_id_services``.

        :param query_db: ORM session
        :param col: column rule id
        :return: ``MetaSecurityColModel``; an empty model when the id is unknown
        """
        return await cls.get_meta_security_col_by_id_services(query_db, col)

    @classmethod
    async def row_detail_services(cls, query_db: AsyncSession, row_id: str):
        """
        Get row rule details; same result as ``get_meta_security_row_by_id_services``.

        :param query_db: ORM session
        :param row_id: row rule id
        :return: ``MetaSecurityRowModel``; an empty model when the id is unknown
        """
        return await cls.get_meta_security_row_by_id_services(query_db, row_id)

    @classmethod
    async def edit_meta_security_col_services(cls, request: Request, query_db: AsyncSession, page_object: MetaSecurityColModel):
        """
        Update a column rule with the fields set on *page_object*.

        :param request: request object (unused)
        :param query_db: ORM session
        :param page_object: column rule carrying ``colId`` and the changed fields
        :return: ``CrudResponseModel`` on success
        :raises ServiceException: the rule does not exist
        """
        edit_col = page_object.model_dump(exclude_unset=True)
        # Ask the DAO directly. The *_by_id_services helper returns an empty model
        # for unknown ids, and a pydantic model instance is always truthy.
        if not await MetaSecurityDao.get_meta_security_col_by_id(query_db, page_object.colId):
            raise ServiceException(message=f'列配置{page_object.colId}不存在')
        try:
            await MetaSecurityDao.update_meta_security_col(query_db, edit_col)
            await query_db.commit()
            return CrudResponseModel(is_success=True, message='编辑列配置成功')
        except Exception:
            await query_db.rollback()
            raise

    @classmethod
    async def edit_meta_security_row_services(cls, request: Request, query_db: AsyncSession, page_object: MetaSecurityRowModel):
        """
        Update a row rule with the fields set on *page_object*.

        :param request: request object (unused)
        :param query_db: ORM session
        :param page_object: row rule carrying ``rowId`` and the changed fields
        :return: ``CrudResponseModel`` on success
        :raises ServiceException: the rule does not exist
        """
        edit_row = page_object.model_dump(exclude_unset=True)
        # Ask the DAO directly; an empty pydantic model would always be truthy.
        if not await MetaSecurityDao.get_meta_security_row_by_id(query_db, page_object.rowId):
            raise ServiceException(message=f'行配置{page_object.rowId}不存在')
        try:
            await MetaSecurityDao.update_meta_security_row(query_db, edit_row)
            await query_db.commit()
            return CrudResponseModel(is_success=True, message='编辑行配置成功')
        except Exception:
            await query_db.rollback()
            raise

    @classmethod
    async def delete_meta_security_col_services(cls, request: Request, query_db: AsyncSession, page_object: DeleteMetaSecurityModel):
        """
        Delete column rules by a comma-separated id list; unknown ids are skipped.

        :param request: request object (unused)
        :param query_db: ORM session
        :param page_object: carries ``metaSecurity_ids``
        :return: ``CrudResponseModel`` on success
        :raises ServiceException: no ids were given
        """
        if not page_object.metaSecurity_ids:
            raise ServiceException(message='传入列配置ID为空')
        col_id_list = [i.strip() for i in page_object.metaSecurity_ids.split(',') if i.strip()]
        try:
            for col_id in col_id_list:
                if await MetaSecurityDao.get_meta_security_col_by_id(query_db, col_id):
                    await MetaSecurityDao.delete_meta_security_col(query_db, col_id)
            await query_db.commit()
            return CrudResponseModel(is_success=True, message='删除列配置成功')
        except Exception:
            await query_db.rollback()
            raise

    @classmethod
    async def delete_meta_security_row_services(cls, request: Request, query_db: AsyncSession, page_object: DeleteMetaSecurityModel):
        """
        Delete row rules by a comma-separated id list; unknown ids are skipped.

        :param request: request object (unused)
        :param query_db: ORM session
        :param page_object: carries ``metaSecurity_ids``
        :return: ``CrudResponseModel`` on success
        :raises ServiceException: no ids were given
        """
        if not page_object.metaSecurity_ids:
            raise ServiceException(message='传入行配置ID为空')
        row_id_list = [i.strip() for i in page_object.metaSecurity_ids.split(',') if i.strip()]
        try:
            for row_id in row_id_list:
                if await MetaSecurityDao.get_meta_security_row_by_id(query_db, row_id):
                    await MetaSecurityDao.delete_meta_security_row(query_db, row_id)
            await query_db.commit()
            return CrudResponseModel(is_success=True, message='删除行配置成功')
        except Exception:
            await query_db.rollback()
            raise

    @classmethod
    async def getMetaSercuitybysql(cls, request: Request, query_db: AsyncSession, page_object: MetaSecurityApiModel):
        """
        Execute user SQL against a registered data source with row/column security applied.

        The flow is:

        1. Authenticate the caller and reject write/DDL keywords.
        2. Resolve and connect to the data source.
        3. Run the SQL in one of two ways:

           * Members of ``DATA_SECURITY_ADMIN_ROLE_ID`` run it unchanged.
           * Everyone else gets every referenced ``schema.table`` rewritten into
             a controlled sub-query, and only that rewritten SQL is executed.

        :param request: request object, forwarded to ``get_data_source_tree``
        :param query_db: ORM session of this application
        :param page_object: API payload (credentials, ``dbRCode``, ``sqlStr``, paging)
        :return: dict with ``ctrlSql`` and ``data``. Administrators also get
            ``message``; controlled queries also get ``tablesRowCol``.
        :raises ServiceException: authentication failure, empty or forbidden SQL
        """
        # 1. authenticate the caller and screen the SQL
        if not page_object.username:
            raise ServiceException(data='', message='用户名不能为空!')
        user = await login_by_account(query_db, page_object.username)
        if not user:
            raise ServiceException(data='', message='用户不存在')
        # NOTE(review): plain equality against the stored value; confirm whether
        # stored passwords are hashed (then a hash verification is required).
        if not page_object.password == user[0].password:
            raise ServiceException(data='', message='用户密码错误!')
        if not page_object.sqlStr:
            raise ServiceException(data='', message='SQL 不能为空!')
        if cls._FORBIDDEN_SQL_PATTERN.search(page_object.sqlStr):
            raise ServiceException(data='', message='SQL 中包含敏感词(UPDATE, DELETE, INSERT, DROP, ALTER, TRUNCATE),禁止执行!')
        query_user = await UserDao.get_user_by_id(query_db, user_id=user[0].user_id)
        role_id_list = [item.role_id for item in (query_user.get('user_role_info') or [])]

        # 2. resolve the data source and open a verified connection
        dsDataResource = await get_data_source_tree(request, page_object)
        db_type = dsDataResource["type"]
        dbConnent = await cls.get_db_engine(db_type, dsDataResource["connectionParams"])
        try:
            # 3. schema.table names referenced by the SQL
            sqlScheamAndTable = await get_tables_from_sql(page_object.sqlStr)
            oldStrSql = (
                generate_pagination_sql(page_object, db_type)
                if page_object.isPage
                else page_object.sqlStr
            )

            # 4. administrators get the uncontrolled result; everyone else only
            #    ever executes the controlled rewrite below
            if cls.DATA_SECURITY_ADMIN_ROLE_ID in role_id_list:
                result = await cls.execute_sql(dbConnent, oldStrSql, "原始")
                return {
                    "ctrlSql": page_object.sqlStr,
                    "data": result,
                    "message": "数据安全管理员权限",
                }

            # 5. actual column names of each referenced table
            table_columns = await cls.get_columns_from_tables(dbConnent, sqlScheamAndTable, db_type)

            # 6. rules for the user and the user's roles, per table
            tablesRowCol = {}
            for table_name in sqlScheamAndTable:
                tablesRowCol[table_name] = await get_table_configs(
                    query_db, page_object, user, role_id_list, table_name
                )

            # 7. controlled sub-query per table, spliced into the user's SQL
            ctrSqlDict = await generate_sql(tablesRowCol, table_columns)
            newStrSql = await replace_table_with_subquery(ctrSqlDict, oldStrSql)

            # 8. execute the controlled SQL
            result = await cls.execute_sql(dbConnent, newStrSql, "控制后")
            return {
                "ctrlSql": newStrSql,
                "data": result,
                "tablesRowCol": tablesRowCol,
            }
        finally:
            # Every request builds its own engine (and pool); release it.
            await dbConnent.dispose()

    @staticmethod
    async def get_db_engine(db_type: str, db_params: dict):
        """
        Create an async SQLAlchemy engine for a data source and verify it with ``SELECT 1``.

        :param db_type: ``mysql`` or ``postgresql`` (case-insensitive)
        :param db_params: connection params with these keys:

            * ``user`` and ``password``
            * ``address``: ``host:port``, optionally prefixed with
              ``jdbc:mysql://`` or ``jdbc:postgresql://``
            * ``database``

            The dict is not modified.
        :return: a verified ``AsyncEngine``; the caller must ``dispose()`` it
        :raises ConnectionError: timeout or SQLAlchemy error while connecting
        :raises RuntimeError: any other failure, including an unsupported ``db_type``
        """
        from urllib.parse import quote_plus

        engine = None
        try:
            address = db_params['address']
            for prefix in ("jdbc:mysql://", "jdbc:postgresql://"):
                if address.startswith(prefix):
                    address = address[len(prefix):]
                    break

            kind = db_type.lower()
            if kind == "mysql":
                driver, connect_args = "mysql+aiomysql", {"connect_timeout": 5}
            elif kind == "postgresql":
                driver, connect_args = "postgresql+asyncpg", {"timeout": 5}
            else:
                raise ValueError("不支持的数据库类型")

            # Credentials must be URL-encoded; otherwise '@', ':' or '/' in a
            # password corrupt the connection URL.
            credentials = f"{quote_plus(str(db_params['user']))}:{quote_plus(str(db_params['password']))}"
            engine = create_async_engine(
                f"{driver}://{credentials}@{address}/{db_params['database']}",
                pool_pre_ping=True,
                connect_args=connect_args,  # fail fast instead of hanging on unreachable hosts
            )
            # Connect once up front so bad settings surface here, not mid-query.
            async with engine.connect() as conn:
                await conn.execute(text("SELECT 1"))
            return engine
        except asyncio.TimeoutError as e:
            error, cause = ConnectionError("数据库连接超时,请检查地址、端口或防火墙"), e
        except SQLAlchemyError as e:
            error, cause = ConnectionError(f"数据库连接失败: {e}"), e
        except Exception as e:
            error, cause = RuntimeError(f"连接过程中发生未知错误: {e}"), e
        # Do not leak the pool of an engine that failed verification.
        if engine is not None:
            await engine.dispose()
        raise error from cause

    @classmethod
    async def execute_sql(cls, engine, sql_query: str, sql_type: str, params=None):
        """
        Execute *sql_query* on *engine* in a short-lived session.

        :param engine: async engine returned by ``get_db_engine``
        :param sql_query: SQL text (wrapped with ``sqlalchemy.text``)
        :param sql_type: label used in error messages
        :param params: optional dict of bind parameters for ``:name`` placeholders
        :return: list of row dicts for row-returning statements, otherwise ``[]``
            (non-row statements are committed)
        :raises TimeoutError: PostgreSQL cancelled the statement (SQLSTATE 57014)
        :raises RuntimeError: any other database error
        """
        async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
        try:
            async with async_session() as session:
                # statement_timeout is a PostgreSQL setting; MySQL rejects it as
                # an unknown system variable, so only set it for PostgreSQL.
                if engine.dialect.name == "postgresql":
                    await session.execute(text("SET statement_timeout = 30000"))
                result = await session.execute(text(sql_query), params)
                if result.returns_rows:
                    columns = list(result.keys())
                    return [dict(zip(columns, row)) for row in result.fetchall()]
                await session.commit()
                return []
        except DBAPIError as e:
            # SQLSTATE 57014 = query_canceled (raised when statement_timeout fires)
            sqlstate = getattr(getattr(e, "orig", None), "sqlstate", None)
            if sqlstate == "57014":
                raise TimeoutError(
                    "SQL 执行超过 30 秒已被数据库中断,请先查询数据量或使用分页。"
                ) from e
            raise RuntimeError(f"{sql_type} 执行 SQL 失败: {e}") from e
        except SQLAlchemyError as e:
            raise RuntimeError(f"{sql_type} 执行 SQL 失败: {e}") from e

    @classmethod
    async def get_columns_from_tables(cls, dbConnent, table_names, db_type: str):
        """
        Look up the column names of each ``schema.table`` in information_schema.

        :param dbConnent: async engine of the data source
        :param table_names: iterable of ``schema.table`` strings
        :param db_type: ``postgresql`` or ``mysql`` (case-insensitive); for MySQL
            the schema part is the database name
        :return: ``{table_name: [column, ...]}`` in table definition order; an
            unknown table maps to an empty list
        :raises ValueError: unsupported ``db_type``
        """
        kind = db_type.lower()
        if kind == "postgresql":
            query = (
                "SELECT column_name FROM information_schema.columns "
                "WHERE CAST(table_schema AS text) = :schema "
                "AND CAST(table_name AS text) = :table "
                "ORDER BY ordinal_position"
            )
        elif kind == "mysql":
            # Alias to lower case: rows are read via row["column_name"] below.
            query = (
                "SELECT COLUMN_NAME AS column_name FROM INFORMATION_SCHEMA.COLUMNS "
                "WHERE TABLE_SCHEMA = :schema AND TABLE_NAME = :table "
                "ORDER BY ORDINAL_POSITION"
            )
        else:
            raise ValueError(f"暂不支持数据库类型: {db_type}")

        columns = {}
        for table_name in table_names:
            schema, table = table_name.split(".")
            # Bind parameters: schema/table come from parsing user-supplied SQL.
            result = await cls.execute_sql(
                dbConnent, query, "字段查询", {"schema": schema, "table": table}
            )
            columns[table_name] = [row["column_name"] for row in result]
        return columns
def unquote_ident(name: str) -> str:
    """
    Strip one matching pair of surrounding quote characters from an identifier.

    Recognised quotes are the double quote, single quote and backtick, e.g.
    ``'"public"' -> 'public'`` and ``'`schema`' -> 'schema'``. Falsy input,
    single characters and unbalanced quoting are returned unchanged.
    """
    if name and len(name) > 1:
        opening, closing = name[0], name[-1]
        if opening == closing and opening in "\"'`":
            return name[1:-1]
    return name
def convert_decimal(obj):
    """
    Recursively replace ``Decimal`` values with ``float``.

    Dicts and lists are rebuilt with their contents converted. Every other
    value, including tuples and their contents, is returned as-is.
    """
    if isinstance(obj, Decimal):
        # float loses precision; str(obj) would keep it if ever needed
        return float(obj)
    if isinstance(obj, dict):
        return {key: convert_decimal(val) for key, val in obj.items()}
    if isinstance(obj, list):
        return list(map(convert_decimal, obj))
    return obj
async def get_table_configs(query_db, page_object, user, role_id_list, table_name):
    """
    Collect the row/column security rules that apply to one table for a user.

    Rules are read for the user itself (owner type ``'0'``) and for each of the
    user's roles (owner type ``'1'``).

    :param query_db: ORM session
    :param page_object: API payload; only ``dbRCode`` (data source code) is read
    :param user: result of ``login_by_account``; ``user[0].user_id`` is used
    :param role_id_list: ids of the user's roles
    :param table_name: ``schema.table``
    :return: dict with ``user_col_list``, ``role_col_list``, ``role_row_list``,
        ``user_row_list`` and ``isHave`` (True when any rule was found)
    """
    schema, table = table_name.split(".")
    ds_code = page_object.dbRCode
    user_id = user[0].user_id

    user_col_list = await MetaSecurityDao.get_api_col_list(
        query_db, ds_code, schema, table, '0', user_id
    )
    role_col_list = [
        col_cfg
        for role_id in role_id_list
        for col_cfg in await MetaSecurityDao.get_api_col_list(
            query_db, ds_code, schema, table, '1', role_id
        )
    ]
    user_row_list = await MetaSecurityDao.get_api_row_list(
        query_db, ds_code, schema, table, '0', user_id
    )
    role_row_list = [
        row_cfg
        for role_id in role_id_list
        for row_cfg in await MetaSecurityDao.get_api_row_list(
            query_db, ds_code, schema, table, '1', role_id
        )
    ]

    return {
        "user_col_list": user_col_list,
        "role_col_list": role_col_list,
        "role_row_list": role_row_list,
        "user_row_list": user_row_list,
        "isHave": any((user_col_list, role_col_list, user_row_list, role_row_list)),
    }
# async def generate_sql(tablesRowCol:dict, table_columns:dict):
# sql_queries = {}
11 months ago
# # 1. 列控制
# # 遍历每个表
# no_configTable_name=""
# for table_name, table_configs in tablesRowCol.items():
# if not table_configs.get("isHave", False):
# no_configTable_name += table_name + ","
# if no_configTable_name:
# no_configTable_name = no_configTable_name.rstrip(',')
# raise ValueError(f"表:{no_configTable_name}均未配置行列数据安全")
# for table_name, config in tablesRowCol.items():
# # 获取该表的字段名
# columns = {col.lower(): col for col in table_columns[table_name]} # 将字段名转为小写
# # 初始化 SELECT 部分:用字典存储字段名,值是 null 字段名
# select_columns = {col: f"null as {col}" for col in columns}
# # 处理角色列配置
# for col in config["role_col_list"]:
# # If dbCName is "ALL", handle it as a special case
# if col.dbCName == "ALL":
# if col.ctrl_type == '0': # If ctrl_type is '0', prefix all columns with null
# for db_column in columns: # Assuming 'user' is the table name
# select_columns[db_column] = f"null as {db_column}" # 仍然保留 null 前缀
# elif col.ctrl_type == '1': # If ctrl_type is '1', use actual column names
# for db_column in columns:
# select_columns[db_column] = db_column # 使用实际字段名
# else:
# # Handle specific columns listed in dbCName
# db_columns = [db_column.strip().lower() for db_column in col.dbCName.split(",")]
# for db_column in db_columns:
# db_column = db_column.strip()
# if db_column in columns: # Check if the column exists in the table
# if col.ctrl_type == '0': # If ctrl_type is '0', prefix with null
# select_columns[db_column] = f"null as {db_column}" # 仍然保留 null 前缀
# elif col.ctrl_type == '1': # If ctrl_type is '1', use actual column name
# select_columns[db_column] = db_column # 使用实际字段名
# # 处理用户列配置
# for col in config["user_col_list"]:
# if col.dbCName == "ALL": # 如果 dbCName 为 "ALL"
# if col.ctrl_type == "0": # ctrlType 为 0,字符串字段
# for db_column in columns: # 对所有字段加上 null
# select_columns[db_column] = f"null as {db_column}" # 仍然保留 null 前缀
# elif col.ctrl_type == "1": # ctrlType 为 1,实际数据库字段
# for db_column in columns: # 使用实际字段名,不加 null
# select_columns[db_column] = db_column # 使用实际字段名
# else: # 处理 dbCName 不为 "ALL" 的情况
# db_columns = [db_column.strip().lower() for db_column in col.dbCName.split(",")]
# for db_column in db_columns:
# db_column = db_column.strip()
# if db_column in columns:
# if col.ctrl_type == "0":
# select_columns[db_column] = f"null as {db_column}" # 仍然保留 null 前缀
# elif col.ctrl_type == "1":
# select_columns[db_column] = db_column # 使用实际字段名
# # 生成 SQL 查询
# sql_queries[table_name] = f"SELECT {', '.join(select_columns.values())} FROM {table_name}"
# # 2.行控制
# select_rows={}
# # 处理角色行配置
# for row in config["role_row_list"]:
# # 仅仅对固定值有效,不加行限制
# if row.ctrl_value == "ALL" and row.ctrl_type == '0':
# # 控制方式 --固定值
# select_rows[row.dbCName] = ""
# else:
# if row.ctrl_type == '0':
# # row.ctrl_value 是逗号分隔的字符串时,改为 IN 语句
# if "," in row.ctrl_value:
# # 将 ctrl_value 按逗号分割,并用单引号包裹每个值
# values = [f"'{value.strip()}'" for value in row.ctrl_value.split(",")]
# select_rows[row.dbCName] = f"{row.dbCName} IN ({', '.join(values)})"
# else:
# select_rows[row.dbCName] = f"{row.dbCName} = '{row.ctrl_value}'"
# if row.ctrl_type == '1':
# tab_col_value=row.ctrl_value.split(".")
# if len(tab_col_value) != 2:
# raise RuntimeError(f"{row.dbCName}字段控制类型为表字段,未维护正确的值")
# select_rows[row.dbCName] = f"{row.dbCName} in (select {tab_col_value[1]} from {row.dbSName}.{tab_col_value[0]} where user_id = '1')"
# # 处理用户行配置
# for row in config["user_row_list"]:
# # 仅仅对固定值有效,不加行限制
# if row.ctrl_value == "ALL" and row.ctrl_type == '0':
# # 控制方式 --固定值
# select_rows[row.dbCName] = ""
# else:
# if row.ctrl_type == '0':
# # row.obj_value 是逗号分隔的字符串时,改为 IN 语句
# if "," in row.ctrl_value:
# # 将 obj_value 按逗号分割,并用单引号包裹每个值
# values = [f"'{value.strip()}'" for value in row.ctrl_value.split(",")]
# select_rows[row.dbCName] = f"{row.dbCName} IN ({', '.join(values)})"
# else:
# select_rows[row.dbCName] = f"{row.dbCName} = '{row.ctrl_value}'"
# if row.ctrl_type == '1':
# tab_col_value=row.ctrl_value.split(".")
# if len(tab_col_value) != 2:
# raise RuntimeError(f"{row.dbCName}字段控制类型为表字段,未维护正确的值")
# select_rows[row.dbCName] = f"{row.dbCName} in (select {tab_col_value[1]} from {row.dbSName}.{tab_col_value[0]} where user_id = '1')"
# if select_rows.values():
# where_conditions = " AND ".join(select_rows.values())
# if where_conditions:
# sql_queries[table_name] += " WHERE " + where_conditions
# else:
# sql_queries[table_name] += " WHERE 1 = 0"
# return sql_queries
async def generate_sql(tablesRowCol: dict, table_columns: dict):
    """
    Build one access-controlled SELECT statement per table.

    Column rules: ``'0'`` hides a column and ``'1'`` shows it. A hidden verdict
    always wins, and columns without any rule are hidden. Hidden columns are
    still selected as ``null as <column>``, so the result keeps its shape.

    Row rules are ``'0'`` (fixed value(s)) or ``'1'`` (values taken from
    another table), and they are AND-ed into a WHERE clause. A fixed-value
    rule of ``ALL`` lifts the row restriction only when no other row rule
    applies. With no usable row rule the query returns no rows
    (``WHERE 1 = 0``).

    :param tablesRowCol: ``{schema.table: config}``. Each config holds
        ``role_col_list``, ``user_col_list``, ``role_row_list``,
        ``user_row_list`` and ``isHave``.
    :param table_columns: ``{schema.table: [column, ...]}``
    :return: ``{schema.table: select_sql}``
    :raises ValueError: a table has no security rule at all, or no known columns
    :raises RuntimeError: a ``'1'`` row rule whose value is not ``table.column``
    """
    allow_all = object()  # sentinel: row rule that lifts the row restriction

    def quote_literal(value) -> str:
        # Standard SQL escaping: double embedded single quotes so a value such
        # as O'Brien cannot terminate the literal early.
        return "'" + str(value).replace("'", "''") + "'"

    def apply_col_rules(visibility: dict, col_cfg_list) -> None:
        for cfg in col_cfg_list:
            if cfg.dbCName == "ALL":
                targets = list(visibility)
            else:
                targets = [c.strip().lower() for c in cfg.dbCName.split(",")]
            for name in targets:
                if name not in visibility:
                    continue
                if cfg.ctrl_type == '0':
                    visibility[name] = '0'  # hidden always wins
                elif cfg.ctrl_type == '1' and visibility[name] != '0':
                    visibility[name] = '1'

    def build_row_condition(row):
        value = "" if row.ctrl_value is None else str(row.ctrl_value)
        if row.ctrl_type == '0':
            if value == "ALL":
                return allow_all
            if "," in value:
                values = ", ".join(quote_literal(v.strip()) for v in value.split(","))
                return f"{row.dbCName} IN ({values})"
            return f"{row.dbCName} = {quote_literal(value)}"
        if row.ctrl_type == '1':
            tab_col = value.split(".")
            if len(tab_col) != 2:
                raise RuntimeError(
                    f"{row.dbCName} 字段控制类型为表字段,但未维护正确的值"
                )
            ref_table, ref_column = tab_col
            # NOTE(review): user_id is hard-coded to '1' rather than the
            # requesting user; confirm whether that was intended.
            return (
                f"{row.dbCName} IN ("
                f"SELECT {ref_column} FROM {row.dbSName}.{ref_table} "
                f"WHERE user_id = '1')"
            )
        return None

    # 0. refuse the whole query if any referenced table has no policy at all
    no_config_tables = [
        table_name
        for table_name, cfg in tablesRowCol.items()
        if not cfg.get("isHave", False)
    ]
    if no_config_tables:
        raise ValueError(f"您没有查看{','.join(no_config_tables)} 表记录的权限,请联系管理员配置相关行/列 数据安全策略")

    sql_queries = {}
    for table_name, config in tablesRowCol.items():
        # lower-case column name -> column name as reported by the database
        columns = {col.lower(): col for col in (table_columns.get(table_name) or [])}
        if not columns:
            # would otherwise produce the invalid "SELECT  FROM <table>"
            raise ValueError(f"未获取到表 {table_name} 的字段信息,请确认表是否存在")

        # 1. column control: None = no rule (hidden), '0' = hidden, '1' = visible
        visibility = dict.fromkeys(columns)
        apply_col_rules(visibility, config.get("role_col_list") or [])
        apply_col_rules(visibility, config.get("user_col_list") or [])
        select_columns = [
            col if visibility[col] == '1' else f"null as {col}"
            for col in columns
        ]
        sql = f"SELECT {', '.join(select_columns)} FROM {table_name}"

        # 2. row control: role rules first, then user rules
        where_conditions = []
        allow_all_rows = False
        row_rules = [*(config.get("role_row_list") or []), *(config.get("user_row_list") or [])]
        for row in row_rules:
            condition = build_row_condition(row)
            if condition is allow_all:
                allow_all_rows = True
            elif condition:
                where_conditions.append(condition)

        # 3. WHERE: explicit conditions win; no row rule at all means no rows
        if where_conditions:
            sql += " WHERE " + " AND ".join(where_conditions)
        elif not allow_all_rows:
            sql += " WHERE 1 = 0"
        sql_queries[table_name] = sql
    return sql_queries
async def replace_table_with_subquery(ctrSqlDict, oldStrSql):
    """
    Splice the access-controlled sub-queries into the user's SQL.

    Every ``FROM``/``JOIN`` reference to a controlled table is replaced by
    ``(<controlled select>) <alias>``. The user's alias is kept; when there is
    none (or the next word is a keyword such as WHERE/ON), the bare table name
    is used as the alias. ``schema.table.column`` references are then
    rewritten to ``alias.column``.

    :param ctrSqlDict: ``{schema.table: controlled_select_sql}``
    :param oldStrSql: SQL written by the caller
    :return: the rewritten SQL
    :raises ValueError: a controlled table is not referenced through FROM/JOIN
        (e.g. a comma-separated FROM list or quoted identifiers). It would
        otherwise be queried without row/column control.
    """
    sql_keywords = frozenset({
        "SELECT", "INSERT", "UPDATE", "DELETE", "MERGE", "TRUNCATE",
        "VALUES", "RETURNING", "FROM", "WHERE", "GROUP", "HAVING", "ORDER",
        "LIMIT", "OFFSET", "DISTINCT", "ALL", "UNION", "INTERSECT", "EXCEPT",
        "JOIN", "INNER", "LEFT", "RIGHT", "FULL", "CROSS", "NATURAL", "USING", "ON",
        "TABLE", "VIEW", "INDEX", "PRIMARY", "KEY", "FOREIGN", "REFERENCES", "NOT",
        "NULL", "UNIQUE", "CHECK", "DEFAULT", "IF", "ELSE", "CASE", "WHEN", "THEN",
        "END", "LOOP", "FOR", "WHILE", "CREATE", "ALTER", "DROP", "COMMENT",
        "EXISTS", "IN", "IS", "LIKE", "ILIKE", "SIMILAR", "BETWEEN", "AND", "OR", "ANY",
        "SOME", "FETCH", "NEXT", "ONLY", "ASC", "DESC", "GRANT", "REVOKE", "ROLE",
        "USER", "CURRENT_DATE", "CURRENT_TIME", "CURRENT_TIMESTAMP",
    })
    table_alias_map = {}  # schema.table -> alias used for the derived table

    for table_name, subquery in ctrSqlDict.items():
        from_join_pattern = (
            r'\b(FROM|JOIN)\s+'
            r'((?:[a-zA-Z_][a-zA-Z0-9_]*\.)?' + re.escape(table_name) + r')'
            # The table name must end here, so "s.user" cannot match the
            # prefix of "s.users" or of a longer dotted name.
            r'(?![\w.])'
            r'(\s+(?:AS\s+)?(\w+))?'
        )

        def from_join_replace(match, table_name=table_name, subquery=subquery):
            keyword, original_table, alias_part, alias_name = match.groups()
            if alias_name and alias_name.upper() not in sql_keywords:
                # keep the user's alias (alias_part starts with whitespace)
                table_alias_map[table_name] = alias_name
                return f"{keyword} ({subquery}){alias_part}"
            # No alias, or the captured word is a keyword: alias the derived
            # table by its bare name and keep the captured text after it.
            alias = original_table.split('.')[-1]
            table_alias_map[table_name] = alias
            return f"{keyword} ({subquery}) AS {alias}{alias_part or ''}"

        oldStrSql, replaced = re.subn(
            from_join_pattern, from_join_replace, oldStrSql, flags=re.IGNORECASE
        )
        if not replaced:
            # Fail closed: leaving the table untouched would bypass the policy.
            raise ValueError(f"无法对表 {table_name} 应用数据安全控制,请通过 FROM/JOIN 直接引用该表")

        # schema.table.column -> alias.column (not inside a longer dotted name)
        alias = table_alias_map.get(table_name, table_name.split('.')[-1])
        column_ref_pattern = r'(?<![\w.])' + re.escape(table_name) + r'\.(\w+)'
        oldStrSql = re.sub(
            column_ref_pattern,
            lambda m: f"{alias}.{m.group(1)}",
            oldStrSql,
            flags=re.IGNORECASE,
        )
    return oldStrSql
async def get_data_source_tree(request: Request, current_user: MetaSecurityApiModel, timeout: float = 30):
    """
    Fetch the data source whose name equals ``current_user.dbRCode`` from DolphinScheduler.

    :param request: request object (unused)
    :param current_user: API payload. ``username`` and ``password`` are sent as
        DolphinScheduler credentials; ``dbRCode`` selects the data source.
    :param timeout: HTTP timeout in seconds (``requests`` waits forever without one)
    :return: the data source dict with ``connectionParams`` decoded from JSON
    :raises Exception: non-200 response, or no data source with that name
    """
    import functools

    url = f'{AppConfig.ds_server_url}/dolphinscheduler/datasources/withpwdlist?pageNo=1&pageSize=100'
    headers = {'dashUserName': current_user.username, 'dashPassword': current_user.password}
    # requests is blocking: run it in the default executor so the event loop
    # keeps serving other requests while DolphinScheduler responds.
    # NOTE(review): verify=False disables TLS certificate verification.
    # NOTE(review): only the first 100 data sources are searched (pageSize=100).
    response = await asyncio.get_running_loop().run_in_executor(
        None,
        functools.partial(requests.get, url, headers=headers, verify=False, timeout=timeout),
    )
    if response.status_code != 200:
        raise Exception(f'根据数据源ID:{current_user.dbRCode}获取数据源信息失败,状态: {response.reason}')

    data = json.loads(response.text)
    total_list = (data.get("data") or {}).get("totalList") or []
    for item in total_list:
        if item["name"] == current_user.dbRCode:
            params = item["connectionParams"]
            # connectionParams arrives as a JSON string; decode it once
            if isinstance(params, str):
                item["connectionParams"] = json.loads(params)
            return item
    raise Exception(f'根据数据源ID:{current_user.dbRCode}获取数据源信息失败,状态: {response.reason}')
11 months ago
async def test_connection(db_content):
    """
    Verify connectivity by running ``SELECT 1`` on an async engine.

    :param db_content: async engine to test
    :raises Exception: "数据源连接失败", chained to the underlying error
    """
    try:
        async with db_content.connect() as connection:
            # SQLAlchemy 2.0-style connections reject plain strings; use text().
            await connection.scalar(text("SELECT 1"))
    except Exception as e:
        raise Exception("数据源连接失败") from e
def generate_pagination_sql(page_object: "MetaSecurityApiModel", db_type: str) -> str:
    """
    Wrap ``page_object.sqlStr`` with database-specific pagination.

    :param page_object: object exposing ``sqlStr``, ``pageNum`` and ``pageSize``.
        Missing or zero values default to page 1 and 10 rows; negative values
        are clamped to 1.
    :param db_type: database type, case-insensitive (MYSQL, POSTGRESQL,
        SQLSERVER, ORACLE)
    :return: paginated SQL
    :raises ValueError: unsupported database type
    """
    page_num = max(int(page_object.pageNum or 1), 1)
    page_size = max(int(page_object.pageSize or 10), 1)
    offset = (page_num - 1) * page_size
    # Drop trailing whitespace and semicolons so the clause is not appended
    # after a statement terminator. The clause goes on a new line so that a
    # trailing "-- comment" in the user's SQL cannot swallow it.
    base_sql = page_object.sqlStr.rstrip("; \t\r\n\f\v")

    db_type = db_type.upper()
    if db_type in ("MYSQL", "POSTGRESQL"):
        return f"{base_sql}\nLIMIT {page_size} OFFSET {offset}"
    if db_type == "SQLSERVER":
        # NOTE(review): assumes an "id" column and no existing ORDER BY.
        return f"{base_sql}\nORDER BY id OFFSET {offset} ROWS FETCH NEXT {page_size} ROWS ONLY"
    if db_type == "ORACLE":
        # NOTE(review): assumes an "id" column and no existing ORDER BY.
        return (
            "SELECT * FROM (\n"
            "SELECT a.*, ROWNUM rnum FROM (\n"
            f"{base_sql}\nORDER BY id\n"
            f") a WHERE ROWNUM <= {offset + page_size}\n"
            f") WHERE rnum > {offset}"
        )
    raise ValueError(f"不支持的数据库类型: {db_type}")
def _extract_identifiers(token):
    """
    Return ``'schema.table'`` for a schema-qualified sqlparse ``Identifier``.

    Any other token, including an ``IdentifierList`` or an unqualified
    identifier, yields ``None``. A trailing alias (``schema.table AS x``)
    does not affect the result.
    """
    if not isinstance(token, Identifier):
        return None
    table = token.get_real_name()      # e.g. "table"
    schema = token.get_parent_name()   # e.g. "schema", None when unqualified
    return f"{schema}.{table}" if table and schema else None
async def get_tables_from_sql(sql_query: str):
    """
    Return the sorted, de-duplicated ``schema.table`` names referenced by *sql_query*.

    The SQL is parsed with sqlparse. A table is recorded when it directly
    follows one of these keywords inside a statement or sub-query, at any
    nesting depth (derived tables, ``IN (SELECT ...)``, CTEs, scalar
    sub-queries):

    * FROM
    * INTO
    * any JOIN variant (``LEFT JOIN``, ``INNER JOIN``, ...)
    * UPDATE

    Only schema-qualified names are returned, and table functions are ignored.

    :param sql_query: SQL text
    :return: sorted list of ``schema.table`` strings
    :raises ValueError: when no ``schema.table`` reference is found
    """
    tables = set()

    def introduces_table(tok) -> bool:
        # sqlparse emits "LEFT JOIN", "INNER JOIN", ... as a single keyword
        # token, so match any keyword ending in JOIN. UPDATE is typed as DML.
        if tok.ttype is Keyword:
            word = tok.normalized.upper()
            return word in ("FROM", "INTO") or word.endswith("JOIN")
        return tok.ttype is DML and tok.normalized.upper() == "UPDATE"

    def record(tok) -> None:
        if isinstance(tok, Identifier):
            name = _extract_identifiers(tok)
            if name:
                tables.add(name)
        elif isinstance(tok, IdentifierList):
            for ident in tok.get_identifiers():
                name = _extract_identifiers(ident)
                if name:
                    tables.add(name)
        elif not tok.is_group:
            # A bare 'schema.table' token that sqlparse did not group. Groups
            # such as a Parenthesis sub-query are handled by walk() instead.
            parts = tok.value.strip().strip('`"\'').split(".")
            if len(parts) == 2:
                tables.add(f"{parts[0]}.{parts[1]}")

    def walk(token_list) -> None:
        children = token_list.tokens
        # Only a (sub-)statement, i.e. a token list with a direct DML child
        # such as SELECT, introduces tables. This ignores FROM inside function
        # arguments, e.g. EXTRACT(YEAR FROM t.created_at).
        is_statement = any(child.ttype is DML for child in children)
        for idx, tok in enumerate(children):
            if is_statement and introduces_table(tok):
                _, nxt = token_list.token_next(idx, skip_ws=True, skip_cm=True)
                # Function after FROM = table function (e.g. generate_series()).
                if nxt is not None and not isinstance(nxt, Function):
                    record(nxt)
            if tok.is_group:
                walk(tok)

    for stmt in sqlparse.parse(sql_query):
        walk(stmt)

    if not tables:
        raise ValueError("SQL 解析失败,未找到任何 schema.table 结构")
    return sorted(tables)