瀏覽代碼

更新代码

gushoubang 10 月之前
父節點
當前提交
401247139d

+ 5 - 0
.env.sample

@@ -0,0 +1,5 @@
+TEMP_FOLDER = './_temp'
+CHROMA_PATH = "chroma"
+COLLECTION_NAME = 'siwei_ai'
+LLM_MODEL = 'qwen2:7b'
+TEXT_EMBEDDING_MODEL = 'nomic-embed-text'

二進制
__pycache__/embed.cpython-310.pyc


二進制
__pycache__/get_vector_db.cpython-310.pyc


二進制
__pycache__/query.cpython-310.pyc


+ 75 - 20
app.py

@@ -10,6 +10,16 @@ import os
 from vocal import voice_text
 from voice_translation_test import vocal_text
 from flask_cors import CORS
+import time
+from dotenv import load_dotenv
+
+# Load .env BEFORE importing embed / query / get_vector_db: those modules read
+# their settings (TEMP_FOLDER, CHROMA_PATH, COLLECTION_NAME,
+# TEXT_EMBEDDING_MODEL, LLM_MODEL) with os.getenv() at import time, so loading
+# the file afterwards would leave them on their hard-coded defaults.
+load_dotenv()
+
+from embed import embed  # noqa: E402  (must follow load_dotenv)
+from query import query  # noqa: E402
+from get_vector_db import get_vector_db  # noqa: E402
+
+# Scratch directory for uploaded files; created eagerly so saving never fails
+# on a missing folder.
+TEMP_FOLDER = os.getenv('TEMP_FOLDER', './_temp')
+os.makedirs(TEMP_FOLDER, exist_ok=True)
+
 
 app = Flask(__name__)
 CORS(app)
@@ -35,6 +45,46 @@ os.makedirs(UPLOAD_FOLDER, exist_ok=True)
 
 # 后台接口
 
+@app.route('/embed', methods=['POST'])
+def route_embed():
+    """Handle POST /embed: validate the uploaded file and hand it to `embed`.
+
+    Responds with 400 when the request has no 'file' part, when the file has
+    an empty filename, or when `embed` reports failure; otherwise 200.
+    The elapsed wall-clock time is printed for every request that reaches
+    the `embed` call.
+    """
+    started_at = time.time()
+
+    uploads = request.files
+    if 'file' not in uploads:
+        return jsonify({"error": "No file part"}), 400
+
+    upload = uploads['file']
+    if upload.filename == '':
+        return jsonify({"error": "No selected file"}), 400
+
+    succeeded = embed(upload)
+    print("Time taken for embedding: ", time.time() - started_at)
+
+    if not succeeded:
+        return jsonify({"error": "File embedded unsuccessfully"}), 400
+    return jsonify({"message": "File embedded successfully"}), 200
+    
+
+
+def route_query(msg):
+    """Run `query` on *msg* and wrap the answer in a response dict.
+
+    Returns a dict with keys "data" (the answer), "code" (200) and "type"
+    ("answer") when `query` yields a truthy result; otherwise returns the
+    pair ({"error": ...}, 400).
+    """
+    answer = query(msg)
+    if not answer:
+        # NOTE(review): this is a (dict, status) pair, not a dict. A caller
+        # that passes the return value straight to jsonify() will serialize
+        # the pair itself instead of using 400 as the HTTP status -- confirm
+        # which shape callers expect.
+        return {"error": "Something went wrong"}, 400
+
+    return {"data": answer, "code": 200, "type": "answer"}
+
+@app.route('/delete', methods=['DELETE'])
+def route_delete():
+    """Handle DELETE /delete: drop the vector store's collection, reply 200."""
+    get_vector_db().delete_collection()
+    return jsonify({"message": "Collection deleted successfully"}), 200
 
 @app.route("/")
 def home():
@@ -174,35 +224,40 @@ def inputMsg():
     print(data['msg'])
 
     msg = data['msg']
+    type = data['type']
+    if type == 'selectLand':
     # 调用大模型解析
     # 这里调用大模型,并返回解析结果
 
     # 示例:用户输入一条消息
     # msg= "我计划在抱坡区选取适宜地块作为工业用地,要求其在城市开发边界内,离小学大于1000m,坡度小于25度,用地面积在80-100亩之间。"
-    res = update_chat_history(msg)
-    print(res)  # 打印生成的回复
-
-    addtress = ['抱坡区', '天涯区', '崖州区', '海棠区', '吉阳区']
-    land = ['园地', '耕地', '林地', '草地', '湿地', '公共卫生用地', '老年人社会福利用地', '儿童社会福利用地', '残疾人社会福利用地', '其他社会福利用地', '零售商业用地', '居住用地', '批发市场用地', '餐饮用地', '旅馆用地', '公用设施营业网点用地', '娱乐用地', '康体用地', '一类工业用地', '二类工业用地', '广播电视设施用地', '环卫用地', '消防用地', '干渠', '水工设施用地', '其他公用设施用地', '公园绿地', '防护绿地', '广场用地', '军事设施用地', '使领馆用地', '宗教用地', '文物古迹用地', '监教场所用地', '殡葬用地', '其他特殊用地', '河流水面', '湖泊水面', '水库水面', '坑塘水面', '沟渠', '冰川及常年积雪', '渔业基础设施用海', '增养殖用海', '捕捞海域', '工业用海', '盐田用海', '固体矿产用海', '油气用海', '可再生能源用海', '海底电缆管道用海', '港口用海', '农业设施建设用地', '耕地', '园地', '林地', '工矿用地', '畜禽养殖设施建设用地', '水产养殖设施建设用地', '城镇住宅用地', '草地', '湿地', '留白用地', '陆地水域', '游憩用海', '特殊用海', '特殊用地', '其他海域',  '绿地与开敞空间用地', '水田', '水浇地', '旱地', '果园', '茶园', '橡胶园', '其他园地', '乔木林地', '竹林地', '城镇社区服务设施用地', '农村宅基地', '农村社区服务设施用地', '机关团体用地', '科研用地', '文化用地', '教育用地', '体育用地', '医疗卫生用地', '社会福利用地', '商业用地', '商务金融用地',
-            '二类农村宅基地', '图书与展览用地', '文化活动用地', '高等教育用地', '中等职业教育用地', '体育训练用地', '其他交通设施用地', '供水用地', '排水用地', '供电用地', '供燃气用地', '供热用地', '通信用地', '邮政用地', '医院用地', '基层医疗卫生设施用地', '田间道', '盐碱地', '沙地', '裸土地', '裸岩石砾地', '村道用地', '村庄内部道路用地', '渔业用海', '工矿通信用海', '其他土地', '公共管理与公共服务用地', '仓储用地', '交通运输用地', '公用设施用地', '交通运输用海', '航运用海', '路桥隧道用海', '风景旅游用海', '文体休闲娱乐用海', '军事用海', '其他特殊用海', '空闲地', '田坎', '港口码头用地', '管道运输用地', '城市轨道交通用地', '城镇道路用地', '交通场站用地', '一类城镇住宅用地', '二类城镇住宅用地', '三类城镇住宅用地', '一类农村宅基地', '商业服务业用地', '三类工业用地', '一类物流仓储用地', '二类物流仓储用地', '三类物流仓储用地', '盐田', '对外交通场站用地', '公共交通场站用地', '社会停车场用地', '中小学用地', '幼儿园用地', '其他教育用地', '体育场馆用地', '灌木林地', '其他林地', '天然牧草地', '人工牧草地', '其他草地', '森林沼泽', '灌丛沼泽', '沼泽草地', '其他沼泽地', '沿海滩涂', '内陆滩涂', '红树林地', '乡村道路用地', '种植设施建设用地', '娱乐康体用地', '其他商业服务业用地', '工业用地', '采矿用地', '物流仓储用地', '储备库用地', '铁路用地', '公路用地', '机场用地']
-    json_res = res
-    if json_res != "未找到相关数据":
-        try:
-            json_res = json.loads(json_res)
-            districtName = json_res["districtName"]
-            landType = json_res["landType"]
-            # if landType != "未找到相关数据" and landType != "" and districtName  != "未找到相关数据"and districtName != "":
-            if landType in land and districtName in addtress:
-                json_res = jsonResToDict(json_res)
-            else:
+        res = update_chat_history(msg)
+        print(res)  # 打印生成的回复
+
+        addtress = ['抱坡区', '天涯区', '崖州区', '海棠区', '吉阳区']
+        land = ['园地', '耕地', '林地', '草地', '湿地', '公共卫生用地', '老年人社会福利用地', '儿童社会福利用地', '残疾人社会福利用地', '其他社会福利用地', '零售商业用地', '居住用地', '批发市场用地', '餐饮用地', '旅馆用地', '公用设施营业网点用地', '娱乐用地', '康体用地', '一类工业用地', '二类工业用地', '广播电视设施用地', '环卫用地', '消防用地', '干渠', '水工设施用地', '其他公用设施用地', '公园绿地', '防护绿地', '广场用地', '军事设施用地', '使领馆用地', '宗教用地', '文物古迹用地', '监教场所用地', '殡葬用地', '其他特殊用地', '河流水面', '湖泊水面', '水库水面', '坑塘水面', '沟渠', '冰川及常年积雪', '渔业基础设施用海', '增养殖用海', '捕捞海域', '工业用海', '盐田用海', '固体矿产用海', '油气用海', '可再生能源用海', '海底电缆管道用海', '港口用海', '农业设施建设用地', '耕地', '园地', '林地', '工矿用地', '畜禽养殖设施建设用地', '水产养殖设施建设用地', '城镇住宅用地', '草地', '湿地', '留白用地', '陆地水域', '游憩用海', '特殊用海', '特殊用地', '其他海域',  '绿地与开敞空间用地', '水田', '水浇地', '旱地', '果园', '茶园', '橡胶园', '其他园地', '乔木林地', '竹林地', '城镇社区服务设施用地', '农村宅基地', '农村社区服务设施用地', '机关团体用地', '科研用地', '文化用地', '教育用地', '体育用地', '医疗卫生用地', '社会福利用地', '商业用地', '商务金融用地',
+                '二类农村宅基地', '图书与展览用地', '文化活动用地', '高等教育用地', '中等职业教育用地', '体育训练用地', '其他交通设施用地', '供水用地', '排水用地', '供电用地', '供燃气用地', '供热用地', '通信用地', '邮政用地', '医院用地', '基层医疗卫生设施用地', '田间道', '盐碱地', '沙地', '裸土地', '裸岩石砾地', '村道用地', '村庄内部道路用地', '渔业用海', '工矿通信用海', '其他土地', '公共管理与公共服务用地', '仓储用地', '交通运输用地', '公用设施用地', '交通运输用海', '航运用海', '路桥隧道用海', '风景旅游用海', '文体休闲娱乐用海', '军事用海', '其他特殊用海', '空闲地', '田坎', '港口码头用地', '管道运输用地', '城市轨道交通用地', '城镇道路用地', '交通场站用地', '一类城镇住宅用地', '二类城镇住宅用地', '三类城镇住宅用地', '一类农村宅基地', '商业服务业用地', '三类工业用地', '一类物流仓储用地', '二类物流仓储用地', '三类物流仓储用地', '盐田', '对外交通场站用地', '公共交通场站用地', '社会停车场用地', '中小学用地', '幼儿园用地', '其他教育用地', '体育场馆用地', '灌木林地', '其他林地', '天然牧草地', '人工牧草地', '其他草地', '森林沼泽', '灌丛沼泽', '沼泽草地', '其他沼泽地', '沿海滩涂', '内陆滩涂', '红树林地', '乡村道路用地', '种植设施建设用地', '娱乐康体用地', '其他商业服务业用地', '工业用地', '采矿用地', '物流仓储用地', '储备库用地', '铁路用地', '公路用地', '机场用地']
+        json_res = res
+        if json_res != "未找到相关数据":
+            try:
+                json_res = json.loads(json_res)
+                districtName = json_res["districtName"]
+                landType = json_res["landType"]
+                # if landType != "未找到相关数据" and landType != "" and districtName  != "未找到相关数据"and districtName != "":
+                if landType in land and districtName in addtress:
+                    json_res = jsonResToDict(json_res)
+                else:
+                    json_res = "未找到相关数据"
+                    json_res = jsonResToDict_wrong(json_res)
+            except:
                 json_res = "未找到相关数据"
                 json_res = jsonResToDict_wrong(json_res)
-        except:
+        else:
             json_res = "未找到相关数据"
             json_res = jsonResToDict_wrong(json_res)
-    else:
-        json_res = "未找到相关数据"
-        json_res = jsonResToDict_wrong(json_res)
+    elif type == 'answer':
+        json_res = route_query(msg)
+
     # 返回响应
     return jsonify(json_res)
 

二進制
chroma/cd7cb5a8-0622-4833-a6a2-d812be9d5da4/data_level0.bin


二進制
chroma/cd7cb5a8-0622-4833-a6a2-d812be9d5da4/header.bin


二進制
chroma/cd7cb5a8-0622-4833-a6a2-d812be9d5da4/length.bin


+ 0 - 0
chroma/cd7cb5a8-0622-4833-a6a2-d812be9d5da4/link_lists.bin


二進制
chroma/chroma.sqlite3


+ 48 - 0
embed.py

@@ -0,0 +1,48 @@
+import os
+from datetime import datetime
+from werkzeug.utils import secure_filename
+from langchain_community.document_loaders import UnstructuredPDFLoader
+from langchain_text_splitters import RecursiveCharacterTextSplitter
+from get_vector_db import get_vector_db
+
+TEMP_FOLDER = os.getenv('TEMP_FOLDER', './_temp')
+
+# Accept only filenames whose final extension is "pdf" (case-insensitive)
+def allowed_file(filename):
+    """Return True when *filename* contains a dot and ends in a pdf extension."""
+    _stem, dot, extension = filename.rpartition('.')
+    return dot == '.' and extension.lower() == 'pdf'
+
+# Store the upload in TEMP_FOLDER under a timestamp-prefixed, sanitized name
+def save_file(file):
+    """Save *file* into TEMP_FOLDER and return the path it was written to.
+
+    The stored name is the current timestamp, an underscore, and the
+    `secure_filename`-sanitized original filename.
+    """
+    stamp = datetime.now().timestamp()
+    stored_name = f"{stamp}_{secure_filename(file.filename)}"
+    destination = os.path.join(TEMP_FOLDER, stored_name)
+    file.save(destination)
+    return destination
+
+# Read a PDF from disk and cut its content into chunks for embedding
+def load_and_split_data(file_path):
+    """Load the PDF at *file_path* and return it split into chunks.
+
+    Chunks are produced by a RecursiveCharacterTextSplitter configured with
+    chunk_size=7500 and chunk_overlap=100.
+    """
+    documents = UnstructuredPDFLoader(file_path=file_path).load()
+    splitter = RecursiveCharacterTextSplitter(chunk_size=7500, chunk_overlap=100)
+    return splitter.split_documents(documents)
+
+# Main function to handle the embedding process
+def embed(file):
+    """Embed an uploaded PDF into the vector database.
+
+    The upload is saved to a temporary file, loaded and split into chunks,
+    added to the vector store and persisted. The temporary file is removed
+    afterwards, whether or not processing succeeded.
+
+    Returns True when the file was accepted and indexed, False when the
+    upload is missing, has no filename, or is not a PDF. Exceptions raised
+    while loading or indexing propagate to the caller.
+    """
+    # Test the object itself first so a missing upload (e.g. None) is
+    # rejected instead of failing on the .filename attribute access.
+    if not file or not file.filename or not allowed_file(file.filename):
+        return False
+
+    file_path = save_file(file)
+    try:
+        chunks = load_and_split_data(file_path)
+        db = get_vector_db()
+        db.add_documents(chunks)
+        db.persist()
+    finally:
+        # Clean up even when parsing or indexing raises; otherwise every
+        # failed upload would leave its temporary copy on disk.
+        os.remove(file_path)
+
+    return True

+ 18 - 0
get_vector_db.py

@@ -0,0 +1,18 @@
+import os
+from langchain_community.embeddings import OllamaEmbeddings
+from langchain_community.vectorstores.chroma import Chroma
+
+CHROMA_PATH = os.getenv('CHROMA_PATH', 'chroma')
+COLLECTION_NAME = os.getenv('COLLECTION_NAME', 'siwei_ai')
+TEXT_EMBEDDING_MODEL = os.getenv('TEXT_EMBEDDING_MODEL', 'nomic-embed-text')
+
+def get_vector_db():
+    """Build and return a Chroma vector store handle.
+
+    The store uses COLLECTION_NAME as its collection, CHROMA_PATH as its
+    persistence directory, and Ollama embeddings from TEXT_EMBEDDING_MODEL.
+    A new handle is constructed on every call.
+    """
+    embedder = OllamaEmbeddings(model=TEXT_EMBEDDING_MODEL, show_progress=True)
+    return Chroma(
+        collection_name=COLLECTION_NAME,
+        persist_directory=CHROMA_PATH,
+        embedding_function=embedder,
+    )

+ 61 - 0
query.py

@@ -0,0 +1,61 @@
+import os
+from langchain_community.chat_models import ChatOllama
+from langchain.prompts import ChatPromptTemplate, PromptTemplate
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.runnables import RunnablePassthrough
+from langchain.retrievers.multi_query import MultiQueryRetriever
+from get_vector_db import get_vector_db
+
+LLM_MODEL = os.getenv('LLM_MODEL', 'qwen2:7b')
+
+# Build the two prompts used by query(): one that asks the model for alternative
+# phrasings of the user question, and one that asks it to answer from context.
+# Returns the pair (QUERY_PROMPT, prompt) in that order.
+def get_prompt():
+    # Prompt with a single {question} variable; its text asks for five
+    # alternative versions of the question, separated by newlines.
+    QUERY_PROMPT = PromptTemplate(
+        input_variables=["question"],
+        template="""You are an AI language model assistant. Your task is to generate five
+        different versions of the given user question to retrieve relevant documents from
+        a vector database. By generating multiple perspectives on the user question, your
+        goal is to help the user overcome some of the limitations of the distance-based
+        similarity search. Provide these alternative questions separated by newlines.
+        Original question: {question}""",
+    )
+
+    # Answering prompt: takes {context} and {question}, and instructs the model
+    # to answer in Chinese using only the supplied context.
+    template = """Answer the question in Chinese based ONLY on the following context:
+    {context}
+    Question: {question}
+    """
+
+    prompt = ChatPromptTemplate.from_template(template)
+
+    return QUERY_PROMPT, prompt
+
+# Answer a question from the documents stored in the vector database
+def query(input):
+    """Return the model's answer to *input*, or None when *input* is falsy.
+
+    The question is passed through a MultiQueryRetriever (built from the
+    vector store's retriever, the chat model and the query prompt) to gather
+    context, then the answering prompt and the same model produce the reply,
+    which is parsed to a plain string.
+
+    NOTE: the parameter name shadows the `input` builtin; it is kept because
+    it is part of the function's signature.
+    """
+    if not input:
+        return None
+
+    llm = ChatOllama(model=LLM_MODEL)
+    db = get_vector_db()
+    query_prompt, answer_prompt = get_prompt()
+
+    # Retriever that uses the model and query_prompt on top of the store's
+    # own retriever.
+    retriever = MultiQueryRetriever.from_llm(
+        db.as_retriever(),
+        llm,
+        prompt=query_prompt,
+    )
+
+    # The retriever fills "context"; the raw question is forwarded unchanged.
+    inputs = {"context": retriever, "question": RunnablePassthrough()}
+    chain = inputs | answer_prompt | llm | StrOutputParser()
+
+    return chain.invoke(input)

+ 190 - 0
requirements1.txt

@@ -0,0 +1,190 @@
+aiohttp==3.9.5
+aiosignal==1.3.1
+annotated-types==0.7.0
+antlr4-python3-runtime==4.9.3
+anyio==4.4.0
+asgiref==3.8.1
+attrs==23.2.0
+backoff==2.2.1
+bcrypt==4.1.3
+beautifulsoup4==4.12.3
+blinker==1.8.2
+build==1.2.1
+cachetools==5.3.3
+certifi==2024.6.2
+cffi==1.16.0
+chardet==5.2.0
+charset-normalizer==3.3.2
+chroma-hnswlib==0.7.3
+chromadb==0.5.3
+click==8.1.7
+coloredlogs==15.0.1
+contourpy==1.2.1
+cryptography==42.0.8
+cycler==0.12.1
+dataclasses-json==0.6.7
+deepdiff==7.0.1
+Deprecated==1.2.14
+dnspython==2.6.1
+effdet==0.4.1
+email_validator==2.2.0
+emoji==2.12.1
+et-xmlfile==1.1.0
+fastapi==0.111.0
+fastapi-cli==0.0.4
+filelock==3.15.4
+filetype==1.2.0
+Flask==3.0.3
+flatbuffers==24.3.25
+fonttools==4.53.0
+frozenlist==1.4.1
+fsspec==2024.6.1
+google-api-core==2.19.1
+google-auth==2.30.0
+google-cloud-vision==3.7.2
+googleapis-common-protos==1.63.2
+grpcio==1.64.1
+grpcio-status==1.62.2
+h11==0.14.0
+httpcore==1.0.5
+httptools==0.6.1
+httpx==0.27.0
+huggingface-hub==0.23.4
+humanfriendly==10.0
+idna==3.7
+importlib_metadata==7.1.0
+importlib_resources==6.4.0
+iopath==0.1.10
+itsdangerous==2.2.0
+Jinja2==3.1.4
+joblib==1.4.2
+jsonpatch==1.33
+jsonpath-python==1.0.6
+jsonpointer==3.0.0
+kiwisolver==1.4.5
+kubernetes==30.1.0
+langchain==0.2.6
+langchain-community==0.2.6
+langchain-core==0.2.10
+langchain-text-splitters==0.2.2
+langdetect==1.0.9
+langsmith==0.1.82
+layoutparser==0.3.4
+lxml==5.2.2
+Markdown==3.6
+markdown-it-py==3.0.0
+MarkupSafe==2.1.5
+marshmallow==3.21.3
+matplotlib==3.9.0
+mdurl==0.1.2
+mmh3==4.1.0
+monotonic==1.6
+mpmath==1.3.0
+multidict==6.0.5
+mypy-extensions==1.0.0
+nest-asyncio==1.6.0
+networkx==3.3
+nltk==3.8.1
+numpy==1.26.4
+oauthlib==3.2.2
+olefile==0.47
+omegaconf==2.3.0
+onnx==1.16.1
+onnxruntime==1.18.1
+opencv-python==4.10.0.84
+openpyxl==3.1.5
+opentelemetry-api==1.25.0
+opentelemetry-exporter-otlp-proto-common==1.25.0
+opentelemetry-exporter-otlp-proto-grpc==1.25.0
+opentelemetry-instrumentation==0.46b0
+opentelemetry-instrumentation-asgi==0.46b0
+opentelemetry-instrumentation-fastapi==0.46b0
+opentelemetry-proto==1.25.0
+opentelemetry-sdk==1.25.0
+opentelemetry-semantic-conventions==0.46b0
+opentelemetry-util-http==0.46b0
+ordered-set==4.1.0
+orjson==3.10.5
+overrides==7.7.0
+packaging==24.1
+pandas==2.2.2
+pdf2image==1.17.0
+pdfminer.six==20231228
+pdfplumber==0.11.1
+pikepdf==9.0.0
+pillow==10.3.0
+pillow_heif==0.16.0
+portalocker==2.10.0
+posthog==3.5.0
+proto-plus==1.24.0
+protobuf==4.25.3
+psutil==6.0.0
+pyasn1==0.6.0
+pyasn1_modules==0.4.0
+pycocotools==2.0.8
+pycparser==2.22
+pydantic==2.7.4
+pydantic_core==2.18.4
+Pygments==2.18.0
+pypandoc==1.13
+pyparsing==3.1.2
+pypdf==4.2.0
+pypdfium2==4.30.0
+PyPika==0.48.9
+pyproject_hooks==1.1.0
+pytesseract==0.3.10
+python-dateutil==2.9.0.post0
+python-docx==1.1.2
+python-dotenv==1.0.1
+python-iso639==2024.4.27
+python-magic==0.4.27
+python-multipart==0.0.9
+python-oxmsg==0.0.1
+python-pptx==0.6.23
+pytz==2024.1
+PyYAML==6.0.1
+rapidfuzz==3.9.3
+regex==2024.5.15
+requests==2.32.3
+requests-oauthlib==2.0.0
+requests-toolbelt==1.0.0
+rich==13.7.1
+rsa==4.9
+safetensors==0.4.3
+scipy==1.14.0
+shellingham==1.5.4
+six==1.16.0
+sniffio==1.3.1
+soupsieve==2.5
+SQLAlchemy==2.0.31
+starlette==0.37.2
+sympy==1.12.1
+tabulate==0.9.0
+tenacity==8.4.2
+timm==1.0.7
+tokenizers==0.19.1
+# torch==2.3.1
+# torchvision==0.18.1
+tqdm==4.66.4
+transformers==4.42.3
+typer==0.12.3
+typing-inspect==0.9.0
+typing_extensions==4.12.2
+tzdata==2024.1
+ujson==5.10.0
+unstructured==0.14.9
+unstructured-client==0.23.8
+unstructured-inference==0.7.36
+unstructured.pytesseract==0.3.12
+urllib3==2.2.2
+uvicorn==0.30.1
+# uvloop==0.19.0
+watchfiles==0.22.0
+websocket-client==1.8.0
+websockets==12.0
+Werkzeug==3.0.3
+wrapt==1.16.0
+xlrd==2.0.1
+XlsxWriter==3.2.0
+yarl==1.9.4
+zipp==3.19.2