瀏覽代碼

重构结构

root 6 月之前
父節點
當前提交
0b8c9327e8

文件差異過大導致無法顯示
+ 0 - 21
111.py


+ 8 - 8
app.py

@@ -102,7 +102,8 @@ model = AutoModel(model="E:\\yuyin_model\\Voice_translation", model_revision="v2
 
 
 # 后台接口
-
+# **已迁移
+# 知识问答文件解析
 @app.route('/embed', methods=['POST'])
 def route_embed():
     start_time = time.time()
@@ -124,7 +125,7 @@ def route_embed():
     return jsonify({"error": "File embedded unsuccessfully"}), 400
 
 
-
+# **已迁移
 def route_query(msg):
     response = query(msg)
     # print(response)
@@ -136,7 +137,7 @@ def route_query(msg):
     #     return resObj
     # return {"error": "Something went wrong"}, 400
     return response
-
+# **已迁移
 @app.route('/delete', methods=['DELETE'])
 def route_delete():
     db = get_vector_db()
@@ -144,16 +145,13 @@ def route_delete():
 
     return jsonify({"message": "Collection deleted successfully"}), 200
 
+# **已迁移
 @app.route("/")
 def home():
     return render_template('index.html')
 
 # 后台接口
 
-
-
-
-
 #定义需要替换的词
 target_word = "抱坡"
 target_word_pinyin = lazy_pinyin(target_word)
@@ -170,6 +168,7 @@ def replace_word(text,target_word):
             text = text.replace(word,target_word)
     return text
 
+# *已迁移
 # 文件上传
 @app.route('/upload', methods=['POST'])
 def upload_file():
@@ -319,6 +318,7 @@ def update_chat_history_simple(user_message):
     # 返回机器人的回复
     return bot_message
 
+# *已迁移
 @app.route('/closeMsg', methods=['DELETE'])
 def delMsg():
     global chat_history
@@ -328,7 +328,7 @@ def delMsg():
                     "chat_history": chat_history,
                     })
 
-
+# *已迁移
 @app.route('/msg', methods=['POST'])
 def inputMsg():
     # 从请求中获取JSON数据

二進制
app/__pycache__/routes.cpython-310.pyc


+ 0 - 0
app/common/__init__.py


+ 40 - 0
app/common/res.py

@@ -0,0 +1,40 @@
+# 返回正确信息
+
+def res_success(json_res, type, msg):
+    resObj = {}
+    resObj["code"] = 200
+    resObj["data"] = json_res
+    if type != "":
+        resObj["type"] = type
+    if msg != "":
+        resObj["msg"] = msg
+    return resObj
+
+# 返回错误信息
+
+
+def res_error(json_res, type, msg):
+    resObj = {}
+    resObj["code"] = 500
+    resObj["data"] = json_res
+    if type != "":
+        resObj["type"] = type
+    if msg != "":
+        resObj["msg"] = msg
+    return resObj
+
+# #返回问答信息
+# def jsonResToDict_questions(json_res):
+#     resObj = {}
+#     resObj["data"] = json_res
+#     resObj["code"] = 200
+#     resObj["type"] = "answer"
+#     return resObj
+
+# # 返回错误信息
+# def jsonResToDict_wrong(json_res):
+#     resObj = {}
+#     resObj["data"] = json_res
+#     resObj["code"] = 500
+#     resObj["type"] = "selectLand"
+#     return resObj

+ 2 - 0
app/common/word.py

@@ -0,0 +1,2 @@
+#定义需要替换的词
+target_word = "抱坡"

+ 65 - 1
app/routes.py

@@ -3,20 +3,27 @@ from flask import Blueprint, render_template, request, jsonify
 
 
 from app.services.file_service import parse_file  # 导入 service 中的函数
+from app.services.chat_service import clear_chat_history, create_chat
+from app.services.embed_service import parse_file_to_embed
+from llm_model.get_vector_db import get_vector_db
 
 main_bp = Blueprint('main', __name__)
 
+# 主页
+
 
 @main_bp.route('/')
 def index():
     return render_template('index.html')
 
+# 测试接口
+
 
 @main_bp.route("/hello")
 def hello():
     return "Hello, World!"
 
-# 文件上传
+# 语音文件上传
 
 
 @main_bp.route('/upload', methods=['POST'])
@@ -31,3 +38,60 @@ def upload_file():
     # 保存文件并解析
     res_obj = parse_file(file)
     return jsonify(res_obj), 200
+
+# 关闭聊天记录
+
+
+@main_bp.route('/closeMsg', methods=['DELETE'])
+def del_msg():
+    return jsonify({"msg": "清除成功",
+                    "code": 200,
+                    "chat_history": clear_chat_history(),
+                    })
+
+# 聊天
+
+
+@main_bp.route('/msg', methods=['POST'])
+def input_msg():
+    # 从请求中获取JSON数据
+    data = request.get_json()
+
+    # 检查是否接收到数据
+    if not data:
+        return jsonify({"error": "No data received"}), 400
+
+    # 打印接收到的消息
+    print(data['msg'])
+    msg = data['msg']
+    type = data['type']
+
+    json_res = create_chat(msg, type)
+
+    # 返回响应
+    return jsonify(json_res)
+
+# 知识问答文件解析
+
+
+@main_bp.route('/embed', methods=['POST'])
+def route_embed():
+    if 'file' not in request.files:
+        return jsonify({"error": "No file part"}), 400
+
+    file = request.files['file']
+
+    if file.filename == '':
+        return jsonify({"error": "No selected file"}), 400
+
+    return parse_file_to_embed(file)
+
+# 删除向量库
+
+
+@main_bp.route('/embed', methods=['DELETE'])
+def route_delete():
+    db = get_vector_db()
+    db.delete_collection()
+
+    return jsonify({"message": "Collection deleted successfully"}), 200

二進制
app/services/__pycache__/__init__.cpython-310.pyc


二進制
app/services/__pycache__/file_service.cpython-310.pyc


文件差異過大導致無法顯示
+ 18 - 0
app/services/chat_service.py


+ 48 - 0
app/services/embed_service.py

@@ -0,0 +1,48 @@
+import time
+import os
+
+from llm_model.embed import embed
+from app.common.res import res_success, res_error
+from langchain_community.embeddings import OllamaEmbeddings
+from langchain_community.vectorstores.chroma import Chroma
+
+# 解析文件到向量
+
+
+def parse_file_to_embed(file):
+    start_time = time.time()
+    embedded = embed(file)
+    end_time = time.time()
+    print("Time taken for embedding: ", end_time - start_time)
+
+    if embedded:
+        return res_success(msg="File embedded successfully")
+    else:
+        return res_error(msg="File embedded unsuccessfully")
+
+# 删除向量
+
+
+def delete_embed():
+    db = get_vector_db()
+    db.delete_collection()
+
+    return res_success(msg="Collection deleted successfully")
+
+
+CHROMA_PATH = os.getenv('CHROMA_PATH', 'chroma')
+COLLECTION_NAME = os.getenv('COLLECTION_NAME', 'siwei_ai')
+TEXT_EMBEDDING_MODEL = os.getenv('TEXT_EMBEDDING_MODEL', 'nomic-embed-text')
+
+
+def get_vector_db():
+    embedding = OllamaEmbeddings(
+        model=TEXT_EMBEDDING_MODEL, show_progress=True, num_gpu=0, num_thread=4)
+
+    db = Chroma(
+        collection_name=COLLECTION_NAME,
+        persist_directory=CHROMA_PATH,
+        embedding_function=embedding
+    )
+
+    return db

+ 2 - 19
app/services/file_service.py

@@ -1,11 +1,10 @@
 
 import os
 import uuid
-import re
-
-from pypinyin import lazy_pinyin
 from funasr import AutoModel
 
+from app.utils.pinyin_utils import replace_word
+
 target_word = "抱坡"
 # 模型1
 model = AutoModel(model="E:\\yuyin_model\\Voice_translation", model_revision="v2.0.4",
@@ -14,22 +13,6 @@ model = AutoModel(model="E:\\yuyin_model\\Voice_translation", model_revision="v2
                   use_cuda=True, use_fast=True,
                   )
 
-# 替换同音字
-
-
-def replace_word(text, target_word):
-    words = re.findall(r'\b\w+\b', text)
-    for word in words:
-        if is_same_pinyin(word, target_word):
-            text = text.replace(word, target_word)
-    return text
-
-# 判断拼音是否相同
-
-
-def is_same_pinyin(word1, word2):
-    return lazy_pinyin(word1) == lazy_pinyin(word2)
-
 
 def parse_file(file):
     # 文件保存路径

+ 0 - 0
app/utils/__init__.py


+ 5 - 0
app/utils/log_utils.py

@@ -0,0 +1,5 @@
+import logging
+
+# 配置日志
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)

+ 18 - 0
app/utils/pinyin_utils.py

@@ -0,0 +1,18 @@
+import re
+from pypinyin import lazy_pinyin
+
+# 替换同音字
+
+
+def replace_word(text, target_word):
+    words = re.findall(r'\b\w+\b', text)
+    for word in words:
+        if is_same_pinyin(word, target_word):
+            text = text.replace(word, target_word)
+    return text
+
+# 判断拼音是否相同
+
+
+def is_same_pinyin(word1, word2):
+    return lazy_pinyin(word1) == lazy_pinyin(word2)

+ 48 - 0
llm_model/embed.py

@@ -0,0 +1,48 @@
+import os
+from datetime import datetime
+from werkzeug.utils import secure_filename
+from langchain_community.document_loaders import UnstructuredPDFLoader
+from langchain_text_splitters import RecursiveCharacterTextSplitter
+from get_vector_db import get_vector_db
+
+TEMP_FOLDER = os.getenv('TEMP_FOLDER', './_temp')
+
+# Function to check if the uploaded file is allowed (only PDF files)
+def allowed_file(filename):
+    return '.' in filename and filename.rsplit('.', 1)[1].lower() in {'pdf'}
+
+# Function to save the uploaded file to the temporary folder
+def save_file(file):
+    # Save the uploaded file with a secure filename and return the file path
+    ct = datetime.now()
+    ts = ct.timestamp()
+    filename = str(ts) + "_" + secure_filename(file.filename)
+    file_path = os.path.join(TEMP_FOLDER, filename)
+    file.save(file_path)
+
+    return file_path
+
+# Function to load and split the data from the PDF file
+def load_and_split_data(file_path):
+    # Load the PDF file and split the data into chunks
+    loader = UnstructuredPDFLoader(file_path=file_path)
+    data = loader.load()
+    text_splitter = RecursiveCharacterTextSplitter(chunk_size=7500, chunk_overlap=100)
+    chunks = text_splitter.split_documents(data)
+
+    return chunks
+
+# Main function to handle the embedding process
+def embed(file):
+    # Check if the file is valid, save it, load and split the data, add to the database, and remove the temporary file
+    if file.filename != '' and file and allowed_file(file.filename):
+        file_path = save_file(file)
+        chunks = load_and_split_data(file_path)
+        db = get_vector_db()
+        db.add_documents(chunks)
+        db.persist()
+        os.remove(file_path)
+
+        return True
+
+    return False

+ 18 - 0
llm_model/get_vector_db.py

@@ -0,0 +1,18 @@
+import os
+from langchain_community.embeddings import OllamaEmbeddings
+from langchain_community.vectorstores.chroma import Chroma
+
+CHROMA_PATH = os.getenv('CHROMA_PATH', 'chroma')
+COLLECTION_NAME = os.getenv('COLLECTION_NAME', 'siwei_ai')
+TEXT_EMBEDDING_MODEL = os.getenv('TEXT_EMBEDDING_MODEL', 'nomic-embed-text')
+
+def get_vector_db():
+    embedding = OllamaEmbeddings(model=TEXT_EMBEDDING_MODEL,show_progress=True,num_gpu=0,num_thread=4)
+
+    db = Chroma(
+        collection_name=COLLECTION_NAME,
+        persist_directory=CHROMA_PATH,
+        embedding_function=embedding
+    )
+
+    return db

+ 62 - 0
llm_model/query.py

@@ -0,0 +1,62 @@
+import os
+from langchain_community.chat_models import ChatOllama
+from langchain.prompts import ChatPromptTemplate, PromptTemplate
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.runnables import RunnablePassthrough
+from langchain.retrievers.multi_query import MultiQueryRetriever
+from llm_model.get_vector_db import get_vector_db
+
+LLM_MODEL = os.getenv('LLM_MODEL', 'qwen2:7b')
+
+# Function to get the prompt templates for generating alternative questions and answering based on context
+def get_prompt():
+    QUERY_PROMPT = PromptTemplate(
+        input_variables=["question"],
+        template="""你是一名AI语言模型助理。你的任务是生成三个
+        从中检索相关文档的给定用户问题的不同版本
+        矢量数据库。通过对用户问题生成多个视角
+        目标是帮助用户克服基于距离的一些局限性
+        相似性搜索。请提供这些用换行符分隔的备选问题。
+        Original question: {question}""",
+    )
+
+    template = """仅根据以下上下文用中文回答问题:
+    {context},请严格以markdown格式输出并保障寄送格式正确无误,
+    Question: {question}
+    """
+    # Question: {question}
+
+
+    prompt = ChatPromptTemplate.from_template(template)
+    return QUERY_PROMPT, prompt
+
+# Main function to handle the query process
+def query(input):
+    if input:
+        # Initialize the language model with the specified model name
+        llm = ChatOllama(model=LLM_MODEL,keep_alive=-1,num_gpu=0)
+        # Get the vector database instance
+        db = get_vector_db()
+        # Get the prompt templates
+        QUERY_PROMPT, prompt = get_prompt()
+
+        # Set up the retriever to generate multiple queries using the language model and the query prompt
+        retriever = MultiQueryRetriever.from_llm(
+            db.as_retriever(),
+            llm,
+            prompt=QUERY_PROMPT
+        )
+
+        # Define the processing chain to retrieve context, generate the answer, and parse the output
+        chain = (
+            {"context": retriever, "question": RunnablePassthrough()}
+            | prompt
+            | llm
+            | StrOutputParser()
+        )
+
+        response = chain.invoke(input)
+
+        return response
+
+    return None

+ 0 - 2
test.py

@@ -1,2 +0,0 @@
-print("hello world")
-print("hello world")

+ 44 - 44
vocal.py

@@ -1,58 +1,58 @@
-from modelscope.pipelines import pipeline
-from modelscope.utils.constant import Tasks
-import time
-import torch
+# from modelscope.pipelines import pipeline
+# from modelscope.utils.constant import Tasks
+# import time
+# import torch
 
-# print(torch.__version__) # 查看torch当前版本号
+# # print(torch.__version__) # 查看torch当前版本号
 
-# print(torch.version.cuda) # 编译当前版本的torch使用的cuda版本号
+# # print(torch.version.cuda) # 编译当前版本的torch使用的cuda版本号
 
-# print(torch.cuda.is_available()) # 查看当前cuda是否可用于当前版本的Torch,如果输出True,则表示可用
+# # print(torch.cuda.is_available()) # 查看当前cuda是否可用于当前版本的Torch,如果输出True,则表示可用
 
 
 
-def voice_text(input_video_path,model='iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch'):
-    inference_pipeline = pipeline(
-    task=Tasks.auto_speech_recognition,
-    # model='iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch',
-    model=model, 
-    # model="model\punc_ct-transformer_cn-en-common-vocab471067-large",
-    model_revision="v2.0.4",
-    device='gpu')
-    
-    res = inference_pipeline(input_video_path)
-    # print(res)
-    texts = [item['text'] for item in res]
+# def voice_text(input_video_path,model='iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch'):
+#     inference_pipeline = pipeline(
+#     task=Tasks.auto_speech_recognition,
+#     # model='iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch',
+#     model=model,
+#     # model="model\punc_ct-transformer_cn-en-common-vocab471067-large",
+#     model_revision="v2.0.4",
+#     device='gpu')
 
-    # print(texts)
-    result = ' '.join(texts)
-    return result
+#     res = inference_pipeline(input_video_path)
+#     # print(res)
+#     texts = [item['text'] for item in res]
 
-if  __name__ == "__main__":
-    start_time = time.time()
-    inference_pipeline = pipeline(
-        task=Tasks.auto_speech_recognition,
-        # model='iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch',
-        model='iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch', 
-        # model="model\punc_ct-transformer_cn-en-common-vocab471067-large",
-        model_revision="v2.0.4",
-        device='gpu')
+#     # print(texts)
+#     result = ' '.join(texts)
+#     return result
 
-    # rec_result = inference_pipeline('https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_vad_punc_example.wav')
+# if  __name__ == "__main__":
+#     start_time = time.time()
+#     inference_pipeline = pipeline(
+#         task=Tasks.auto_speech_recognition,
+#         # model='iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch',
+#         model='iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch',
+#         # model="model\punc_ct-transformer_cn-en-common-vocab471067-large",
+#         model_revision="v2.0.4",
+#         device='gpu')
 
-    # 替换为本地语音文件路径
-    local_audio_path = 'data/audio/5bf77846-0193-4f35-92f7-09ce51ee3793.mp3'
-    res = inference_pipeline(local_audio_path)
-    # print(res)
-    texts = [item['text'] for item in res]
+#     # rec_result = inference_pipeline('https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_vad_punc_example.wav')
 
-    # print(texts)
-    result = ' '.join(texts)
-    print(result)
+#     # 替换为本地语音文件路径
+#     local_audio_path = 'data/audio/5bf77846-0193-4f35-92f7-09ce51ee3793.mp3'
+#     res = inference_pipeline(local_audio_path)
+#     # print(res)
+#     texts = [item['text'] for item in res]
 
+#     # print(texts)
+#     result = ' '.join(texts)
+#     print(result)
 
-    end_time = time.time()
-    # 计算时间差
-    elapsed_time = end_time - start_time
 
-    print(f"耗时: {elapsed_time} 秒")
+#     end_time = time.time()
+#     # 计算时间差
+#     elapsed_time = end_time - start_time
+
+#     print(f"耗时: {elapsed_time} 秒")

+ 39 - 39
voice_translation_test.py

@@ -1,52 +1,52 @@
-from funasr import AutoModel
-import time
+# from funasr import AutoModel
+# import time
 
-def vocal_text(input_video_path):
-    model = AutoModel(model="./Voice_translation", model_revision="v2.0.4",
-                    vad_model="./Endpoint_detection", vad_model_revision="v2.0.4",
-                    punc_model="./Ct_punc", punc_model_revision="v2.0.4",
-                    use_cuda=True,use_fast = True,
-                    )
-    res = model.generate(input_video_path, 
-                batch_size_s=30, 
-                hotword='test')
+# def vocal_text(input_video_path):
+#     model = AutoModel(model="./Voice_translation", model_revision="v2.0.4",
+#                     vad_model="./Endpoint_detection", vad_model_revision="v2.0.4",
+#                     punc_model="./Ct_punc", punc_model_revision="v2.0.4",
+#                     use_cuda=True,use_fast = True,
+#                     )
+#     res = model.generate(input_video_path,
+#                 batch_size_s=30,
+#                 hotword='test')
 
-    
-    texts = [item['text'] for item in res]
 
-    
-    result = ' '.join(texts)
-    return result
+#     texts = [item['text'] for item in res]
 
 
-if  __name__ == "__main__":
-    start_time = time.time()
+#     result = ' '.join(texts)
+#     return result
 
 
-    model = AutoModel(model="./Voice_translation", model_revision="v2.0.4",
-                    vad_model="./Endpoint_detection", vad_model_revision="v2.0.4",
-                    punc_model="./Ct_punc", punc_model_revision="v2.0.4",
-                    )
-    res = model.generate(input="./data/audio/5bf77846-0193-4f35-92f7-09ce51ee3793.mp3", 
-                batch_size_s=30, 
-                hotword='test')
+# if  __name__ == "__main__":
+#     start_time = time.time()
 
-    print(res)
-    texts = [item['text'] for item in res]
 
-    print(texts)
-    result = ' '.join(texts)
-    print(result)
+#     model = AutoModel(model="./Voice_translation", model_revision="v2.0.4",
+#                     vad_model="./Endpoint_detection", vad_model_revision="v2.0.4",
+#                     punc_model="./Ct_punc", punc_model_revision="v2.0.4",
+#                     )
+#     res = model.generate(input="./data/audio/5bf77846-0193-4f35-92f7-09ce51ee3793.mp3",
+#                 batch_size_s=30,
+#                 hotword='test')
 
+#     print(res)
+#     texts = [item['text'] for item in res]
 
-# def save(input,savepath):    
-#     outputs = open(savepath, 'w', encoding='utf-8')
-#     outputs.write(input+'\n')
-#     outputs.close()
-# save(input=result,savepath=r"F:\work\voice_translation\datasets\1.txt")
+#     print(texts)
+#     result = ' '.join(texts)
+#     print(result)
 
-    end_time = time.time()
-    # 计算时间差
-    elapsed_time = end_time - start_time
 
-    print(f"耗时: {elapsed_time} 秒")
+# # def save(input,savepath):
+# #     outputs = open(savepath, 'w', encoding='utf-8')
+# #     outputs.write(input+'\n')
+# #     outputs.close()
+# # save(input=result,savepath=r"F:\work\voice_translation\datasets\1.txt")
+
+#     end_time = time.time()
+#     # 计算时间差
+#     elapsed_time = end_time - start_time
+
+#     print(f"耗时: {elapsed_time} 秒")

部分文件因文件數量過多而無法顯示