Browse Source

style: 使用ruff格式化登录模块,优化导入

master
insistence 7 months ago
parent
commit
bd242c95f7
  1. 99
      ruoyi-fastapi-backend/module_admin/controller/login_controller.py
  2. 16
      ruoyi-fastapi-backend/module_admin/dao/login_dao.py
  3. 14
      ruoyi-fastapi-backend/module_admin/entity/vo/login_vo.py
  4. 247
      ruoyi-fastapi-backend/module_admin/service/login_service.py

99
ruoyi-fastapi-backend/module_admin/controller/login_controller.py

@ -1,79 +1,104 @@
from fastapi import APIRouter import uuid
from module_admin.service.login_service import * from datetime import datetime, timedelta
from module_admin.entity.vo.login_vo import * from fastapi import APIRouter, Depends, Request
from module_admin.dao.login_dao import * from jose import jwt
from module_admin.annotation.log_annotation import log_decorator from sqlalchemy.ext.asyncio import AsyncSession
from config.env import JwtConfig, RedisInitKeyConfig from typing import Optional
from config.enums import BusinessType from config.enums import BusinessType
from config.env import AppConfig, JwtConfig, RedisInitKeyConfig
from config.get_db import get_db
from module_admin.annotation.log_annotation import log_decorator
from module_admin.entity.vo.common_vo import CrudResponseModel
from module_admin.entity.vo.login_vo import UserLogin, UserRegister, Token
from module_admin.entity.vo.user_vo import CurrentUserModel, EditUserModel
from module_admin.service.login_service import CustomOAuth2PasswordRequestForm, LoginService, oauth2_scheme
from module_admin.service.user_service import UserService
from utils.log_util import logger
from utils.response_util import ResponseUtil from utils.response_util import ResponseUtil
from utils.log_util import *
from datetime import timedelta
loginController = APIRouter() loginController = APIRouter()
@loginController.post("/login", response_model=Token) @loginController.post('/login', response_model=Token)
@log_decorator(title='用户登录', business_type=BusinessType.OTHER, log_type='login') @log_decorator(title='用户登录', business_type=BusinessType.OTHER, log_type='login')
async def login(request: Request, form_data: CustomOAuth2PasswordRequestForm = Depends(), query_db: AsyncSession = Depends(get_db)): async def login(
captcha_enabled = True if await request.app.state.redis.get(f"{RedisInitKeyConfig.SYS_CONFIG.get('key')}:sys.account.captchaEnabled") == 'true' else False request: Request, form_data: CustomOAuth2PasswordRequestForm = Depends(), query_db: AsyncSession = Depends(get_db)
):
captcha_enabled = (
True
if await request.app.state.redis.get(f"{RedisInitKeyConfig.SYS_CONFIG.get('key')}:sys.account.captchaEnabled")
== 'true'
else False
)
user = UserLogin( user = UserLogin(
userName=form_data.username, userName=form_data.username,
password=form_data.password, password=form_data.password,
code=form_data.code, code=form_data.code,
uuid=form_data.uuid, uuid=form_data.uuid,
loginInfo=form_data.login_info, loginInfo=form_data.login_info,
captchaEnabled=captcha_enabled captchaEnabled=captcha_enabled,
) )
result = await LoginService.authenticate_user(request, query_db, user) result = await LoginService.authenticate_user(request, query_db, user)
access_token_expires = timedelta(minutes=JwtConfig.jwt_expire_minutes) access_token_expires = timedelta(minutes=JwtConfig.jwt_expire_minutes)
session_id = str(uuid.uuid4()) session_id = str(uuid.uuid4())
access_token = await LoginService.create_access_token( access_token = await LoginService.create_access_token(
data={ data={
"user_id": str(result[0].user_id), 'user_id': str(result[0].user_id),
"user_name": result[0].user_name, 'user_name': result[0].user_name,
"dept_name": result[1].dept_name if result[1] else None, 'dept_name': result[1].dept_name if result[1] else None,
"session_id": session_id, 'session_id': session_id,
"login_info": user.login_info 'login_info': user.login_info,
}, },
expires_delta=access_token_expires expires_delta=access_token_expires,
) )
if AppConfig.app_same_time_login: if AppConfig.app_same_time_login:
await request.app.state.redis.set(f"{RedisInitKeyConfig.ACCESS_TOKEN.get('key')}:{session_id}", access_token, await request.app.state.redis.set(
ex=timedelta(minutes=JwtConfig.jwt_redis_expire_minutes)) f"{RedisInitKeyConfig.ACCESS_TOKEN.get('key')}:{session_id}",
access_token,
ex=timedelta(minutes=JwtConfig.jwt_redis_expire_minutes),
)
else: else:
# 此方法可实现同一账号同一时间只能登录一次 # 此方法可实现同一账号同一时间只能登录一次
await request.app.state.redis.set(f"{RedisInitKeyConfig.ACCESS_TOKEN.get('key')}:{result[0].user_id}", access_token, await request.app.state.redis.set(
ex=timedelta(minutes=JwtConfig.jwt_redis_expire_minutes)) f"{RedisInitKeyConfig.ACCESS_TOKEN.get('key')}:{result[0].user_id}",
await UserService.edit_user_services(query_db, EditUserModel(userId=result[0].user_id, loginDate=datetime.now(), type='status')) access_token,
ex=timedelta(minutes=JwtConfig.jwt_redis_expire_minutes),
)
await UserService.edit_user_services(
query_db, EditUserModel(userId=result[0].user_id, loginDate=datetime.now(), type='status')
)
logger.info('登录成功') logger.info('登录成功')
# 判断请求是否来自于api文档,如果是返回指定格式的结果,用于修复api文档认证成功后token显示undefined的bug # 判断请求是否来自于api文档,如果是返回指定格式的结果,用于修复api文档认证成功后token显示undefined的bug
request_from_swagger = request.headers.get('referer').endswith('docs') if request.headers.get('referer') else False request_from_swagger = request.headers.get('referer').endswith('docs') if request.headers.get('referer') else False
request_from_redoc = request.headers.get('referer').endswith('redoc') if request.headers.get('referer') else False request_from_redoc = request.headers.get('referer').endswith('redoc') if request.headers.get('referer') else False
if request_from_swagger or request_from_redoc: if request_from_swagger or request_from_redoc:
return {'access_token': access_token, 'token_type': 'Bearer'} return {'access_token': access_token, 'token_type': 'Bearer'}
return ResponseUtil.success( return ResponseUtil.success(msg='登录成功', dict_content={'token': access_token})
msg='登录成功',
dict_content={'token': access_token}
)
@loginController.get("/getInfo", response_model=CurrentUserModel) @loginController.get('/getInfo', response_model=CurrentUserModel)
async def get_login_user_info(request: Request, current_user: CurrentUserModel = Depends(LoginService.get_current_user)): async def get_login_user_info(
request: Request, current_user: CurrentUserModel = Depends(LoginService.get_current_user)
):
logger.info('获取成功') logger.info('获取成功')
return ResponseUtil.success(model_content=current_user) return ResponseUtil.success(model_content=current_user)
@loginController.get("/getRouters") @loginController.get('/getRouters')
async def get_login_user_routers(request: Request, current_user: CurrentUserModel = Depends(LoginService.get_current_user), query_db: AsyncSession = Depends(get_db)): async def get_login_user_routers(
request: Request,
current_user: CurrentUserModel = Depends(LoginService.get_current_user),
query_db: AsyncSession = Depends(get_db),
):
logger.info('获取成功') logger.info('获取成功')
user_routers = await LoginService.get_current_user_routers(current_user.user.user_id, query_db) user_routers = await LoginService.get_current_user_routers(current_user.user.user_id, query_db)
return ResponseUtil.success(data=user_routers) return ResponseUtil.success(data=user_routers)
@loginController.post("/register", response_model=CrudResponseModel) @loginController.post('/register', response_model=CrudResponseModel)
async def register_user(request: Request, user_register: UserRegister, query_db: AsyncSession = Depends(get_db)): async def register_user(request: Request, user_register: UserRegister, query_db: AsyncSession = Depends(get_db)):
user_register_result = await LoginService.register_user_services(request, query_db, user_register) user_register_result = await LoginService.register_user_services(request, query_db, user_register)
logger.info(user_register_result.message) logger.info(user_register_result.message)
@ -111,11 +136,13 @@ async def register_user(request: Request, user_register: UserRegister, query_db:
# return ResponseUtil.error(msg=str(e)) # return ResponseUtil.error(msg=str(e))
@loginController.post("/logout") @loginController.post('/logout')
async def logout(request: Request, token: Optional[str] = Depends(oauth2_scheme)): async def logout(request: Request, token: Optional[str] = Depends(oauth2_scheme)):
payload = jwt.decode(token, JwtConfig.jwt_secret_key, algorithms=[JwtConfig.jwt_algorithm], options={'verify_exp': False}) payload = jwt.decode(
session_id: str = payload.get("session_id") token, JwtConfig.jwt_secret_key, algorithms=[JwtConfig.jwt_algorithm], options={'verify_exp': False}
)
session_id: str = payload.get('session_id')
await LoginService.logout_services(request, session_id) await LoginService.logout_services(request, session_id)
logger.info('退出成功') logger.info('退出成功')
return ResponseUtil.success(msg="退出成功") return ResponseUtil.success(msg='退出成功')

16
ruoyi-fastapi-backend/module_admin/dao/login_dao.py

@ -1,7 +1,7 @@
from sqlalchemy import select, and_ from sqlalchemy import and_, select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from module_admin.entity.do.user_do import SysUser
from module_admin.entity.do.dept_do import SysDept from module_admin.entity.do.dept_do import SysDept
from module_admin.entity.do.user_do import SysUser
async def login_by_account(db: AsyncSession, user_name: str): async def login_by_account(db: AsyncSession, user_name: str):
@ -11,11 +11,17 @@ async def login_by_account(db: AsyncSession, user_name: str):
:param user_name: 用户名 :param user_name: 用户名
:return: 用户对象 :return: 用户对象
""" """
user = (await db.execute( user = (
await db.execute(
select(SysUser, SysDept) select(SysUser, SysDept)
.where(SysUser.user_name == user_name, SysUser.del_flag == '0') .where(SysUser.user_name == user_name, SysUser.del_flag == '0')
.join(SysDept, and_(SysUser.dept_id == SysDept.dept_id, SysDept.status == '0', SysDept.del_flag == '0'), isouter=True) .join(
SysDept,
and_(SysUser.dept_id == SysDept.dept_id, SysDept.status == '0', SysDept.del_flag == '0'),
isouter=True,
)
.distinct() .distinct()
)).first() )
).first()
return user return user

14
ruoyi-fastapi-backend/module_admin/entity/vo/login_vo.py

@ -1,7 +1,7 @@
import re import re
from pydantic import BaseModel, ConfigDict, Field, model_validator from pydantic import BaseModel, ConfigDict, Field, model_validator
from pydantic.alias_generators import to_camel from pydantic.alias_generators import to_camel
from typing import Optional, List, Union from typing import List, Optional, Union
from exceptions.exception import ModelValidatorException from exceptions.exception import ModelValidatorException
from module_admin.entity.vo.menu_vo import MenuModel from module_admin.entity.vo.menu_vo import MenuModel
@ -28,11 +28,11 @@ class UserRegister(BaseModel):
@model_validator(mode='after') @model_validator(mode='after')
def check_password(self) -> 'UserRegister': def check_password(self) -> 'UserRegister':
pattern = r'''^[^<>"'|\\]+$''' pattern = r"""^[^<>"'|\\]+$"""
if self.password is None or re.match(pattern, self.password): if self.password is None or re.match(pattern, self.password):
return self return self
else: else:
raise ModelValidatorException(message="密码不能包含非法字符:< > \" ' \\ |") raise ModelValidatorException(message='密码不能包含非法字符:< > " \' \\ |')
class Token(BaseModel): class Token(BaseModel):
@ -75,9 +75,13 @@ class RouterModel(BaseModel):
name: Optional[str] = Field(default=None, description='路由名称') name: Optional[str] = Field(default=None, description='路由名称')
path: Optional[str] = Field(default=None, description='路由地址') path: Optional[str] = Field(default=None, description='路由地址')
hidden: Optional[bool] = Field(default=None, description='是否隐藏路由,当设置 true 的时候该路由不会再侧边栏出现') hidden: Optional[bool] = Field(default=None, description='是否隐藏路由,当设置 true 的时候该路由不会再侧边栏出现')
redirect: Optional[str] = Field(default=None, description='重定向地址,当设置 noRedirect 的时候该路由在面包屑导航中不可被点击') redirect: Optional[str] = Field(
default=None, description='重定向地址,当设置 noRedirect 的时候该路由在面包屑导航中不可被点击'
)
component: Optional[str] = Field(default=None, description='组件地址') component: Optional[str] = Field(default=None, description='组件地址')
query: Optional[str] = Field(default=None, description='路由参数:如 {"id": 1, "name": "ry"}') query: Optional[str] = Field(default=None, description='路由参数:如 {"id": 1, "name": "ry"}')
always_show: Optional[bool] = Field(default=None, description='当你一个路由下面的children声明的路由大于1个时,自动会变成嵌套的模式--如组件页面') always_show: Optional[bool] = Field(
default=None, description='当你一个路由下面的children声明的路由大于1个时,自动会变成嵌套的模式--如组件页面'
)
meta: Optional[MetaModel] = Field(default=None, description='其他元素') meta: Optional[MetaModel] = Field(default=None, description='其他元素')
children: Optional[Union[List['RouterModel'], None]] = Field(default=None, description='子路由') children: Optional[Union[List['RouterModel'], None]] = Field(default=None, description='子路由')

247
ruoyi-fastapi-backend/module_admin/service/login_service.py

@ -1,24 +1,28 @@
from fastapi import Request, Form
from fastapi import Depends
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jose import JWTError, jwt
import random import random
import uuid import uuid
from datetime import timedelta from datetime import datetime, timedelta
from module_admin.service.user_service import * from fastapi import Depends, Form, Request
from module_admin.entity.vo.login_vo import * from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from module_admin.entity.vo.common_vo import CrudResponseModel from jose import JWTError, jwt
from module_admin.dao.login_dao import * from sqlalchemy.ext.asyncio import AsyncSession
from exceptions.exception import LoginException, AuthException, ServiceException from typing import Dict, List, Optional, Union
from config.constant import CommonConstant, MenuConstant from config.constant import CommonConstant, MenuConstant
from config.env import AppConfig, JwtConfig, RedisInitKeyConfig from config.env import AppConfig, JwtConfig, RedisInitKeyConfig
from config.get_db import get_db from config.get_db import get_db
from exceptions.exception import LoginException, AuthException, ServiceException
from module_admin.dao.login_dao import login_by_account
from module_admin.dao.user_dao import UserDao
from module_admin.entity.do.menu_do import SysMenu
from module_admin.entity.vo.common_vo import CrudResponseModel
from module_admin.entity.vo.login_vo import MenuTreeModel, MetaModel, RouterModel, SmsCode, UserLogin, UserRegister
from module_admin.entity.vo.user_vo import AddUserModel, CurrentUserModel, ResetUserModel, TokenData, UserInfoModel
from module_admin.service.user_service import UserService
from utils.common_util import CamelCaseUtil from utils.common_util import CamelCaseUtil
from utils.pwd_util import * from utils.log_util import logger
from utils.response_util import * from utils.message_util import message_service
from utils.message_util import * from utils.pwd_util import PwdUtil
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="login") oauth2_scheme = OAuth2PasswordBearer(tokenUrl='login')
class CustomOAuth2PasswordRequestForm(OAuth2PasswordRequestForm): class CustomOAuth2PasswordRequestForm(OAuth2PasswordRequestForm):
@ -28,18 +32,24 @@ class CustomOAuth2PasswordRequestForm(OAuth2PasswordRequestForm):
def __init__( def __init__(
self, self,
grant_type: str = Form(default=None, regex="password"), grant_type: str = Form(default=None, regex='password'),
username: str = Form(), username: str = Form(),
password: str = Form(), password: str = Form(),
scope: str = Form(default=""), scope: str = Form(default=''),
client_id: Optional[str] = Form(default=None), client_id: Optional[str] = Form(default=None),
client_secret: Optional[str] = Form(default=None), client_secret: Optional[str] = Form(default=None),
code: Optional[str] = Form(default=""), code: Optional[str] = Form(default=''),
uuid: Optional[str] = Form(default=""), uuid: Optional[str] = Form(default=''),
login_info: Optional[Dict[str, str]] = Form(default=None) login_info: Optional[Dict[str, str]] = Form(default=None),
): ):
super().__init__(grant_type=grant_type, username=username, password=password, super().__init__(
scope=scope, client_id=client_id, client_secret=client_secret) grant_type=grant_type,
username=username,
password=password,
scope=scope,
client_id=client_id,
client_secret=client_secret,
)
self.code = code self.code = code
self.uuid = uuid self.uuid = uuid
self.login_info = login_info self.login_info = login_info
@ -61,47 +71,61 @@ class LoginService:
""" """
await cls.__check_login_ip(request) await cls.__check_login_ip(request)
account_lock = await request.app.state.redis.get( account_lock = await request.app.state.redis.get(
f"{RedisInitKeyConfig.ACCOUNT_LOCK.get('key')}:{login_user.user_name}") f"{RedisInitKeyConfig.ACCOUNT_LOCK.get('key')}:{login_user.user_name}"
)
if login_user.user_name == account_lock: if login_user.user_name == account_lock:
logger.warning("账号已锁定,请稍后再试") logger.warning('账号已锁定,请稍后再试')
raise LoginException(data="", message="账号已锁定,请稍后再试") raise LoginException(data='', message='账号已锁定,请稍后再试')
# 判断请求是否来自于api文档,如果是返回指定格式的结果,用于修复api文档认证成功后token显示undefined的bug # 判断请求是否来自于api文档,如果是返回指定格式的结果,用于修复api文档认证成功后token显示undefined的bug
request_from_swagger = request.headers.get('referer').endswith('docs') if request.headers.get('referer') else False request_from_swagger = (
request_from_redoc = request.headers.get('referer').endswith('redoc') if request.headers.get('referer') else False request.headers.get('referer').endswith('docs') if request.headers.get('referer') else False
)
request_from_redoc = (
request.headers.get('referer').endswith('redoc') if request.headers.get('referer') else False
)
# 判断是否开启验证码,开启则验证,否则不验证(dev模式下来自API文档的登录请求不检验) # 判断是否开启验证码,开启则验证,否则不验证(dev模式下来自API文档的登录请求不检验)
if not login_user.captcha_enabled or ((request_from_swagger or request_from_redoc) and AppConfig.app_env == 'dev'): if not login_user.captcha_enabled or (
(request_from_swagger or request_from_redoc) and AppConfig.app_env == 'dev'
):
pass pass
else: else:
await cls.__check_login_captcha(request, login_user) await cls.__check_login_captcha(request, login_user)
user = await login_by_account(query_db, login_user.user_name) user = await login_by_account(query_db, login_user.user_name)
if not user: if not user:
logger.warning("用户不存在") logger.warning('用户不存在')
raise LoginException(data="", message="用户不存在") raise LoginException(data='', message='用户不存在')
if not PwdUtil.verify_password(login_user.password, user[0].password): if not PwdUtil.verify_password(login_user.password, user[0].password):
cache_password_error_count = await request.app.state.redis.get( cache_password_error_count = await request.app.state.redis.get(
f"{RedisInitKeyConfig.PASSWORD_ERROR_COUNT.get('key')}:{login_user.user_name}") f"{RedisInitKeyConfig.PASSWORD_ERROR_COUNT.get('key')}:{login_user.user_name}"
)
password_error_counted = 0 password_error_counted = 0
if cache_password_error_count: if cache_password_error_count:
password_error_counted = cache_password_error_count password_error_counted = cache_password_error_count
password_error_count = int(password_error_counted) + 1 password_error_count = int(password_error_counted) + 1
await request.app.state.redis.set( await request.app.state.redis.set(
f"{RedisInitKeyConfig.PASSWORD_ERROR_COUNT.get('key')}:{login_user.user_name}", password_error_count, f"{RedisInitKeyConfig.PASSWORD_ERROR_COUNT.get('key')}:{login_user.user_name}",
ex=timedelta(minutes=10)) password_error_count,
ex=timedelta(minutes=10),
)
if password_error_count > 5: if password_error_count > 5:
await request.app.state.redis.delete( await request.app.state.redis.delete(
f"{RedisInitKeyConfig.PASSWORD_ERROR_COUNT.get('key')}:{login_user.user_name}") f"{RedisInitKeyConfig.PASSWORD_ERROR_COUNT.get('key')}:{login_user.user_name}"
)
await request.app.state.redis.set( await request.app.state.redis.set(
f"{RedisInitKeyConfig.ACCOUNT_LOCK.get('key')}:{login_user.user_name}", login_user.user_name, f"{RedisInitKeyConfig.ACCOUNT_LOCK.get('key')}:{login_user.user_name}",
ex=timedelta(minutes=10)) login_user.user_name,
logger.warning("10分钟内密码已输错超过5次,账号已锁定,请10分钟后再试") ex=timedelta(minutes=10),
raise LoginException(data="", message="10分钟内密码已输错超过5次,账号已锁定,请10分钟后再试") )
logger.warning("密码错误") logger.warning('10分钟内密码已输错超过5次,账号已锁定,请10分钟后再试')
raise LoginException(data="", message="密码错误") raise LoginException(data='', message='10分钟内密码已输错超过5次,账号已锁定,请10分钟后再试')
logger.warning('密码错误')
raise LoginException(data='', message='密码错误')
if user[0].status == '1': if user[0].status == '1':
logger.warning("用户已停用") logger.warning('用户已停用')
raise LoginException(data="", message="用户已停用") raise LoginException(data='', message='用户已停用')
await request.app.state.redis.delete( await request.app.state.redis.delete(
f"{RedisInitKeyConfig.PASSWORD_ERROR_COUNT.get('key')}:{login_user.user_name}") f"{RedisInitKeyConfig.PASSWORD_ERROR_COUNT.get('key')}:{login_user.user_name}"
)
return user return user
@classmethod @classmethod
@ -112,11 +136,12 @@ class LoginService:
:return: 校验结果 :return: 校验结果
""" """
black_ip_value = await request.app.state.redis.get( black_ip_value = await request.app.state.redis.get(
f"{RedisInitKeyConfig.SYS_CONFIG.get('key')}:sys.login.blackIPList") f"{RedisInitKeyConfig.SYS_CONFIG.get('key')}:sys.login.blackIPList"
)
black_ip_list = black_ip_value.split(',') if black_ip_value else [] black_ip_list = black_ip_value.split(',') if black_ip_value else []
if request.headers.get('X-Forwarded-For') in black_ip_list: if request.headers.get('X-Forwarded-For') in black_ip_list:
logger.warning("当前IP禁止登录") logger.warning('当前IP禁止登录')
raise LoginException(data="", message="当前IP禁止登录") raise LoginException(data='', message='当前IP禁止登录')
return True return True
@classmethod @classmethod
@ -128,13 +153,14 @@ class LoginService:
:return: 校验结果 :return: 校验结果
""" """
captcha_value = await request.app.state.redis.get( captcha_value = await request.app.state.redis.get(
f"{RedisInitKeyConfig.CAPTCHA_CODES.get('key')}:{login_user.uuid}") f"{RedisInitKeyConfig.CAPTCHA_CODES.get('key')}:{login_user.uuid}"
)
if not captcha_value: if not captcha_value:
logger.warning("验证码已失效") logger.warning('验证码已失效')
raise LoginException(data="", message="验证码已失效") raise LoginException(data='', message='验证码已失效')
if login_user.code != str(captcha_value): if login_user.code != str(captcha_value):
logger.warning("验证码错误") logger.warning('验证码错误')
raise LoginException(data="", message="验证码错误") raise LoginException(data='', message='验证码错误')
return True return True
@classmethod @classmethod
@ -150,13 +176,14 @@ class LoginService:
expire = datetime.utcnow() + expires_delta expire = datetime.utcnow() + expires_delta
else: else:
expire = datetime.utcnow() + timedelta(minutes=30) expire = datetime.utcnow() + timedelta(minutes=30)
to_encode.update({"exp": expire}) to_encode.update({'exp': expire})
encoded_jwt = jwt.encode(to_encode, JwtConfig.jwt_secret_key, algorithm=JwtConfig.jwt_algorithm) encoded_jwt = jwt.encode(to_encode, JwtConfig.jwt_secret_key, algorithm=JwtConfig.jwt_algorithm)
return encoded_jwt return encoded_jwt
@classmethod @classmethod
async def get_current_user(cls, request: Request = Request, token: str = Depends(oauth2_scheme), async def get_current_user(
query_db: AsyncSession = Depends(get_db)): cls, request: Request = Request, token: str = Depends(oauth2_scheme), query_db: AsyncSession = Depends(get_db)
):
""" """
根据token获取当前用户信息 根据token获取当前用户信息
:param request: Request对象 :param request: Request对象
@ -172,31 +199,41 @@ class LoginService:
if token.startswith('Bearer'): if token.startswith('Bearer'):
token = token.split(' ')[1] token = token.split(' ')[1]
payload = jwt.decode(token, JwtConfig.jwt_secret_key, algorithms=[JwtConfig.jwt_algorithm]) payload = jwt.decode(token, JwtConfig.jwt_secret_key, algorithms=[JwtConfig.jwt_algorithm])
user_id: str = payload.get("user_id") user_id: str = payload.get('user_id')
session_id: str = payload.get("session_id") session_id: str = payload.get('session_id')
if user_id is None: if user_id is None:
logger.warning("用户token不合法") logger.warning('用户token不合法')
raise AuthException(data="", message="用户token不合法") raise AuthException(data='', message='用户token不合法')
token_data = TokenData(user_id=int(user_id)) token_data = TokenData(user_id=int(user_id))
except JWTError: except JWTError:
logger.warning("用户token已失效,请重新登录") logger.warning('用户token已失效,请重新登录')
raise AuthException(data="", message="用户token已失效,请重新登录") raise AuthException(data='', message='用户token已失效,请重新登录')
query_user = await UserDao.get_user_by_id(query_db, user_id=token_data.user_id) query_user = await UserDao.get_user_by_id(query_db, user_id=token_data.user_id)
if query_user.get('user_basic_info') is None: if query_user.get('user_basic_info') is None:
logger.warning("用户token不合法") logger.warning('用户token不合法')
raise AuthException(data="", message="用户token不合法") raise AuthException(data='', message='用户token不合法')
if AppConfig.app_same_time_login: if AppConfig.app_same_time_login:
redis_token = await request.app.state.redis.get(f"{RedisInitKeyConfig.ACCESS_TOKEN.get('key')}:{session_id}") redis_token = await request.app.state.redis.get(
f"{RedisInitKeyConfig.ACCESS_TOKEN.get('key')}:{session_id}"
)
else: else:
# 此方法可实现同一账号同一时间只能登录一次 # 此方法可实现同一账号同一时间只能登录一次
redis_token = await request.app.state.redis.get(f"{RedisInitKeyConfig.ACCESS_TOKEN.get('key')}:{query_user.get('user_basic_info').user_id}") redis_token = await request.app.state.redis.get(
f"{RedisInitKeyConfig.ACCESS_TOKEN.get('key')}:{query_user.get('user_basic_info').user_id}"
)
if token == redis_token: if token == redis_token:
if AppConfig.app_same_time_login: if AppConfig.app_same_time_login:
await request.app.state.redis.set(f"{RedisInitKeyConfig.ACCESS_TOKEN.get('key')}:{session_id}", redis_token, await request.app.state.redis.set(
ex=timedelta(minutes=JwtConfig.jwt_redis_expire_minutes)) f"{RedisInitKeyConfig.ACCESS_TOKEN.get('key')}:{session_id}",
redis_token,
ex=timedelta(minutes=JwtConfig.jwt_redis_expire_minutes),
)
else: else:
await request.app.state.redis.set(f"{RedisInitKeyConfig.ACCESS_TOKEN.get('key')}:{query_user.get('user_basic_info').user_id}", redis_token, await request.app.state.redis.set(
ex=timedelta(minutes=JwtConfig.jwt_redis_expire_minutes)) f"{RedisInitKeyConfig.ACCESS_TOKEN.get('key')}:{query_user.get('user_basic_info').user_id}",
redis_token,
ex=timedelta(minutes=JwtConfig.jwt_redis_expire_minutes),
)
role_id_list = [item.role_id for item in query_user.get('user_role_info')] role_id_list = [item.role_id for item in query_user.get('user_role_info')]
if 1 in role_id_list: if 1 in role_id_list:
@ -215,13 +252,13 @@ class LoginService:
postIds=post_ids, postIds=post_ids,
roleIds=role_ids, roleIds=role_ids,
dept=CamelCaseUtil.transform_result(query_user.get('user_dept_info')), dept=CamelCaseUtil.transform_result(query_user.get('user_dept_info')),
role=CamelCaseUtil.transform_result(query_user.get('user_role_info')) role=CamelCaseUtil.transform_result(query_user.get('user_role_info')),
) ),
) )
return current_user return current_user
else: else:
logger.warning("用户token已失效,请重新登录") logger.warning('用户token已失效,请重新登录')
raise AuthException(data="", message="用户token已失效,请重新登录") raise AuthException(data='', message='用户token已失效,请重新登录')
@classmethod @classmethod
async def get_current_user_routers(cls, user_id: int, query_db: AsyncSession): async def get_current_user_routers(cls, user_id: int, query_db: AsyncSession):
@ -232,7 +269,14 @@ class LoginService:
:return: 当前用户路由信息对象 :return: 当前用户路由信息对象
""" """
query_user = await UserDao.get_user_by_id(query_db, user_id=user_id) query_user = await UserDao.get_user_by_id(query_db, user_id=user_id)
user_router_menu = sorted([row for row in query_user.get('user_menu_info') if row.menu_type in [MenuConstant.TYPE_DIR, MenuConstant.TYPE_MENU]], key=lambda x: x.order_num) user_router_menu = sorted(
[
row
for row in query_user.get('user_menu_info')
if row.menu_type in [MenuConstant.TYPE_DIR, MenuConstant.TYPE_MENU]
],
key=lambda x: x.order_num,
)
menus = cls.__generate_menus(0, user_router_menu) menus = cls.__generate_menus(0, user_router_menu)
user_router = cls.__generate_user_router_menu(menus) user_router = cls.__generate_user_router_menu(menus)
return [router.model_dump(exclude_unset=True, by_alias=True) for router in user_router] return [router.model_dump(exclude_unset=True, by_alias=True) for router in user_router]
@ -275,8 +319,8 @@ class LoginService:
title=permission.menu_name, title=permission.menu_name,
icon=permission.icon, icon=permission.icon,
noCache=True if permission.is_cache == 1 else False, noCache=True if permission.is_cache == 1 else False,
link=permission.path if RouterUtil.is_http(permission.path) else None link=permission.path if RouterUtil.is_http(permission.path) else None,
) ),
) )
c_menus = permission.children c_menus = permission.children
if c_menus and permission.menu_type == MenuConstant.TYPE_DIR: if c_menus and permission.menu_type == MenuConstant.TYPE_DIR:
@ -294,17 +338,14 @@ class LoginService:
title=permission.menu_name, title=permission.menu_name,
icon=permission.icon, icon=permission.icon,
noCache=True if permission.is_cache == 1 else False, noCache=True if permission.is_cache == 1 else False,
link=permission.path if RouterUtil.is_http(permission.path) else None link=permission.path if RouterUtil.is_http(permission.path) else None,
), ),
query=permission.query query=permission.query,
) )
children_list.append(children) children_list.append(children)
router.children = children_list router.children = children_list
elif permission.parent_id == 0 and RouterUtil.is_inner_link(permission): elif permission.parent_id == 0 and RouterUtil.is_inner_link(permission):
router.meta = MetaModel( router.meta = MetaModel(title=permission.menu_name, icon=permission.icon)
title=permission.menu_name,
icon=permission.icon
)
router.path = '/' router.path = '/'
children_list: List[RouterModel] = [] children_list: List[RouterModel] = []
router_path = RouterUtil.inner_link_replace_each(permission.path) router_path = RouterUtil.inner_link_replace_each(permission.path)
@ -315,8 +356,8 @@ class LoginService:
meta=MetaModel( meta=MetaModel(
title=permission.menu_name, title=permission.menu_name,
icon=permission.icon, icon=permission.icon,
link=permission.path if RouterUtil.is_http(permission.path) else None link=permission.path if RouterUtil.is_http(permission.path) else None,
) ),
) )
children_list.append(children) children_list.append(children)
router.children = children_list router.children = children_list
@ -334,15 +375,26 @@ class LoginService:
:param user_register: 注册用户对象 :param user_register: 注册用户对象
:return: 注册结果 :return: 注册结果
""" """
register_enabled = True if await request.app.state.redis.get( register_enabled = (
f"{RedisInitKeyConfig.SYS_CONFIG.get('key')}:sys.account.registerUser") == 'true' else False True
captcha_enabled = True if await request.app.state.redis.get( if await request.app.state.redis.get(f"{RedisInitKeyConfig.SYS_CONFIG.get('key')}:sys.account.registerUser")
f"{RedisInitKeyConfig.SYS_CONFIG.get('key')}:sys.account.captchaEnabled") == 'true' else False == 'true'
else False
)
captcha_enabled = (
True
if await request.app.state.redis.get(
f"{RedisInitKeyConfig.SYS_CONFIG.get('key')}:sys.account.captchaEnabled"
)
== 'true'
else False
)
if user_register.password == user_register.confirm_password: if user_register.password == user_register.confirm_password:
if register_enabled: if register_enabled:
if captcha_enabled: if captcha_enabled:
captcha_value = await request.app.state.redis.get( captcha_value = await request.app.state.redis.get(
f"{RedisInitKeyConfig.CAPTCHA_CODES.get('key')}:{user_register.uuid}") f"{RedisInitKeyConfig.CAPTCHA_CODES.get('key')}:{user_register.uuid}"
)
if not captcha_value: if not captcha_value:
raise ServiceException(message='验证码已失效') raise ServiceException(message='验证码已失效')
elif user_register.code != str(captcha_value): elif user_register.code != str(captcha_value):
@ -350,7 +402,7 @@ class LoginService:
add_user = AddUserModel( add_user = AddUserModel(
userName=user_register.username, userName=user_register.username,
nickName=user_register.username, nickName=user_register.username,
password=PwdUtil.get_password_hash(user_register.password) password=PwdUtil.get_password_hash(user_register.password),
) )
result = await UserService.add_user_services(query_db, add_user) result = await UserService.add_user_services(query_db, add_user)
return result return result
@ -369,15 +421,17 @@ class LoginService:
:return: 短信验证码对象 :return: 短信验证码对象
""" """
redis_sms_result = await request.app.state.redis.get( redis_sms_result = await request.app.state.redis.get(
f"{RedisInitKeyConfig.SMS_CODE.get('key')}:{user.session_id}") f"{RedisInitKeyConfig.SMS_CODE.get('key')}:{user.session_id}"
)
if redis_sms_result: if redis_sms_result:
return SmsCode(**dict(is_success=False, sms_code='', session_id='', message='短信验证码仍在有效期内')) return SmsCode(**dict(is_success=False, sms_code='', session_id='', message='短信验证码仍在有效期内'))
is_user = await UserDao.get_user_by_name(query_db, user.user_name) is_user = await UserDao.get_user_by_name(query_db, user.user_name)
if is_user: if is_user:
sms_code = str(random.randint(100000, 999999)) sms_code = str(random.randint(100000, 999999))
session_id = str(uuid.uuid4()) session_id = str(uuid.uuid4())
await request.app.state.redis.set(f"{RedisInitKeyConfig.SMS_CODE.get('key')}:{session_id}", sms_code, await request.app.state.redis.set(
ex=timedelta(minutes=2)) f"{RedisInitKeyConfig.SMS_CODE.get('key')}:{session_id}", sms_code, ex=timedelta(minutes=2)
)
# 此处模拟调用短信服务 # 此处模拟调用短信服务
message_service(sms_code) message_service(sms_code)
@ -395,7 +449,8 @@ class LoginService:
:return: 重置结果 :return: 重置结果
""" """
redis_sms_result = await request.app.state.redis.get( redis_sms_result = await request.app.state.redis.get(
f"{RedisInitKeyConfig.SMS_CODE.get('key')}:{forget_user.session_id}") f"{RedisInitKeyConfig.SMS_CODE.get('key')}:{forget_user.session_id}"
)
if forget_user.sms_code == redis_sms_result: if forget_user.sms_code == redis_sms_result:
forget_user.password = PwdUtil.get_password_hash(forget_user.password) forget_user.password = PwdUtil.get_password_hash(forget_user.password)
forget_user.user_id = (await UserDao.get_user_by_name(query_db, forget_user.user_name)).user_id forget_user.user_id = (await UserDao.get_user_by_name(query_db, forget_user.user_name)).user_id
@ -484,7 +539,9 @@ class RouterUtil:
:param menu: 菜单数对象 :param menu: 菜单数对象
:return: 是否为菜单内部跳转 :return: 是否为菜单内部跳转
""" """
return menu.parent_id == 0 and menu.menu_type == MenuConstant.TYPE_MENU and menu.is_frame == MenuConstant.NO_FRAME return (
menu.parent_id == 0 and menu.menu_type == MenuConstant.TYPE_MENU and menu.is_frame == MenuConstant.NO_FRAME
)
@classmethod @classmethod
def is_inner_link(cls, menu: MenuTreeModel): def is_inner_link(cls, menu: MenuTreeModel):
@ -520,8 +577,8 @@ class RouterUtil:
:param path: 内链域名 :param path: 内链域名
:return: 替换后的内链域名 :return: 替换后的内链域名
""" """
old_values = [CommonConstant.HTTP, CommonConstant.HTTPS, CommonConstant.WWW, ".", ":"] old_values = [CommonConstant.HTTP, CommonConstant.HTTPS, CommonConstant.WWW, '.', ':']
new_values = ["", "", "", "/", "/"] new_values = ['', '', '', '/', '/']
for old, new in zip(old_values, new_values): for old, new in zip(old_values, new_values):
path = path.replace(old, new) path = path.replace(old, new)
return path return path

Loading…
Cancel
Save