You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 

154 lines
7.0 KiB

import datetime
import shutil
import os
import uuid
from fastapi import APIRouter, Request, UploadFile, Form
from fastapi import Depends, File
from config.get_db import get_db
from module_admin.service.login_service import get_current_user
from module_admin.service.aichat_service import *
from module_admin.entity.vo.aichat_vo import *
from module_admin.dao.aichat_dao import *
from utils.response_util import *
from utils.log_util import *
from utils.minio_util import *
from module_admin.aspect.interface_auth import CheckUserInterfaceAuth
from module_admin.annotation.log_annotation import log_decorator
from config.env import MinioConfig
from datetime import datetime
from utils.common_util import bytes2file_response
# Router for the AI chat endpoints; every route registered on it runs the
# get_current_user dependency first.
aichatController = APIRouter(
    dependencies=[Depends(get_current_user)],
)
@aichatController.get(
    "/session/list/{session_id}",
    dependencies=[Depends(CheckUserInterfaceAuth('common'))],
)
async def get_chat_session_list(
    request: Request,
    session_id: str,
    query_db: Session = Depends(get_db),
    current_user: CurrentUserInfoServiceResponse = Depends(get_current_user),
):
    """Return the AI session list for ``session_id``.

    Delegates to ``AiChatService.get_ai_session_list_services`` and wraps the
    result in a 200 response. Any exception is logged and converted into a
    500 response whose message is the exception text.
    """
    try:
        sessions = AiChatService.get_ai_session_list_services(query_db, session_id, current_user)
        ok_message = "获取成功"
        logger.info(ok_message)
        return response_200(data=sessions, message=ok_message)
    except Exception as err:
        logger.exception(err)
        return response_500(data="", message=str(err))
@aichatController.get(
    "/chat/list/{session_id}",
    dependencies=[Depends(CheckUserInterfaceAuth('common'))],
)
async def get_chat_list(
    request: Request,
    session_id: str,
    query_db: Session = Depends(get_db),
    current_user: CurrentUserInfoServiceResponse = Depends(get_current_user),
):
    """Return the chat list for ``session_id``.

    Delegates to ``AiChatService.get_ai_chat_list_services`` and wraps the
    result in a 200 response. Any exception is logged and converted into a
    500 response whose message is the exception text.
    """
    try:
        chats = AiChatService.get_ai_chat_list_services(query_db, session_id, current_user)
        ok_message = "获取成功"
        logger.info(ok_message)
        return response_200(data=chats, message=ok_message)
    except Exception as err:
        logger.exception(err)
        return response_500(data="", message=str(err))
@aichatController.post(
    "/delete/session/{session_id}",
    response_model=CrudChatModel,
    dependencies=[Depends(CheckUserInterfaceAuth('common'))],
)
async def delete_chat_session(
    request: Request,
    session_id: str,
    query_db: Session = Depends(get_db),
    current_user: CurrentUserInfoServiceResponse = Depends(get_current_user),
):
    """Delete chat session ``session_id`` via ``AiChatService.delete_chat_session``.

    The service outcome's ``message`` is logged and the outcome is returned
    in a 200 response whenever the service call does not raise. Any
    exception is logged and converted into a 500 response with its text.
    """
    try:
        deletion = AiChatService.delete_chat_session(query_db, session_id, current_user)
        logger.info(deletion.message)
        return response_200(data=deletion, message=deletion.message)
    except Exception as err:
        logger.exception(err)
        return response_500(data="", message=str(err))
@aichatController.post("/add", response_model=CrudChatModel,
                       dependencies=[Depends(CheckUserInterfaceAuth('common'))])
async def add_chat(request: Request, add_chat: AiChatModel, query_db: Session = Depends(get_db),
                   current_user: CurrentUserInfoServiceResponse = Depends(get_current_user)):
    """Create a chat record through ``AiChatService.add_chat``.

    Returns:
        * 200 with the service result when ``is_success`` is true,
        * 400 with the service's own message when ``is_success`` is false,
        * 500 with the exception text if anything raises.
    """
    try:
        operate_result = AiChatService.add_chat(query_db, add_chat, current_user)
        if operate_result.is_success:
            logger.info(operate_result.message)
            return response_200(data=operate_result, message=operate_result.message)
        logger.warning(operate_result.message)
        # Report this call's failure message. The name used here before was
        # undefined, so every failure surfaced as a NameError/500.
        return response_400(data="", message=operate_result.message)
    except Exception as e:
        logger.exception(e)
        return response_500(data="", message=str(e))
@aichatController.post("/update", response_model=CrudChatModel,
                       dependencies=[Depends(CheckUserInterfaceAuth('common'))])
async def update_chat(request: Request, update_chat: AiChatModel, query_db: Session = Depends(get_db),
                      current_user: CurrentUserInfoServiceResponse = Depends(get_current_user)):
    """Update a chat record through ``AiChatService.update_chat``.

    Returns:
        * 200 with the service result when ``is_success`` is true,
        * 400 with the service's own message when ``is_success`` is false,
        * 500 with the exception text if anything raises.
    """
    try:
        # NOTE(review): current_user is not forwarded to the service, unlike
        # add_chat; confirm the service does its own ownership check.
        operate_result = AiChatService.update_chat(query_db, update_chat)
        if operate_result.is_success:
            logger.info(operate_result.message)
            return response_200(data=operate_result, message=operate_result.message)
        logger.warning(operate_result.message)
        # Report this call's failure message. The name used here before was
        # undefined, so every failure surfaced as a NameError/500.
        return response_400(data="", message=operate_result.message)
    except Exception as e:
        logger.exception(e)
        return response_500(data="", message=str(e))
@aichatController.post("/upload", dependencies=[Depends(get_current_user), Depends(CheckUserInterfaceAuth('common'))])
async def upload_file(request: Request, session_id: str = Form(), file: UploadFile = File(...),
                      current_user: CurrentUserInfoServiceResponse = Depends(get_current_user)):
    """Upload a file into the current user's MinIO bucket under ``/<session_id>/``.

    The object is named ``<stem>_<YYYYmmddHHMMSS><ext>``, built from the
    final path component of the client-supplied file name. The upload is
    staged through a file in a private temporary directory. That directory
    is always removed, including when the upload fails.

    Returns a 200 response whose data holds ``file`` (the generated name)
    and ``bucket`` (the user name). On any exception it returns a 500
    response with the exception text.
    """
    import tempfile

    try:
        # file.filename is client-controlled. Keep only its last path
        # component so "../" cannot escape the staging directory or the
        # session prefix of the object key.
        original_name = os.path.basename(file.filename.replace("\\", "/"))
        file_prefix, file_extension = os.path.splitext(original_name)
        file_id = datetime.now().strftime('%Y%m%d%H%M%S')
        file_location = file_prefix + "_" + file_id + file_extension
        bucket = Bucket(
            minio_address=MinioConfig.minio_address,
            minio_admin=MinioConfig.minio_admin,
            minio_password=MinioConfig.minio_password
        )
        # Stage in a per-request temp dir. This lets concurrent uploads of the
        # same name not clobber each other, and guarantees cleanup on error.
        # NOTE(review): copyfileobj and the MinIO call are synchronous inside
        # an async endpoint and will block the event loop for large files.
        with tempfile.TemporaryDirectory() as staging_dir:
            local_path = os.path.join(staging_dir, file_location)
            with open(local_path, "wb") as buffer:
                shutil.copyfileobj(file.file, buffer)
            bucket.upload_file_to_bucket(current_user.user.user_name,
                                         '/' + session_id + '/' + file_location,
                                         local_path)
        return response_200(data={"is_success": True, "message": "上传成功", "file": file_location, "bucket": current_user.user.user_name},
                            message="上传成功")
    except Exception as e:
        logger.exception(e)
        return response_500(data="", message=str(e))
@aichatController.post("/file/download",
                       dependencies=[Depends(get_current_user), Depends(CheckUserInterfaceAuth('common'))])
async def download_file(request: Request, download_file: DownloadFile):
    """Return a file stored in MinIO as a streaming response.

    Fetches object ``/<session_id>/<file>`` from bucket ``bucket`` (all taken
    from the request body) into a private temporary directory. It then reads
    the bytes into memory and returns them via ``streaming_response_200``.
    The temporary directory is always removed. Any exception is logged and
    returned as a 500 response with the exception text.
    """
    import tempfile

    try:
        file_name = download_file.file
        bucket_name = download_file.bucket
        session_id = download_file.session_id
        # NOTE(review): bucket_name comes straight from the client and nothing
        # here checks it against the caller, while uploads go to the caller's
        # own bucket. Confirm cross-user reads are intended.
        bucket = Bucket(
            minio_address=MinioConfig.minio_address,
            minio_admin=MinioConfig.minio_admin,
            minio_password=MinioConfig.minio_password
        )
        with tempfile.TemporaryDirectory() as staging_dir:
            # file_name is client-controlled. Only its last path component is
            # used on local disk, so it cannot overwrite or delete files
            # outside the staging directory.
            local_name = os.path.basename(file_name.replace("\\", "/")) or "download"
            local_path = os.path.join(staging_dir, local_name)
            # Download from MinIO into the staging file, then load it.
            bucket.download_file_from_bucket(bucket_name, '/' + session_id + '/' + file_name, local_path)
            with open(local_path, "rb") as buffer:
                file_data = buffer.read()
        return streaming_response_200(data=bytes2file_response(file_data))
    except Exception as e:
        logger.exception(e)
        return response_500(data="", message=str(e))