embed.py 1.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748
  1. import os
  2. from datetime import datetime
  3. from werkzeug.utils import secure_filename
  4. from langchain_community.document_loaders import UnstructuredPDFLoader
  5. from langchain_text_splitters import RecursiveCharacterTextSplitter
  6. from llm_model.get_vector_db import get_vector_db
  7. TEMP_FOLDER = os.getenv('TEMP_FOLDER', './_temp')
  8. # Function to check if the uploaded file is allowed (only PDF files)
  9. def allowed_file(filename):
  10. return '.' in filename and filename.rsplit('.', 1)[1].lower() in {'pdf'}
  11. # Function to save the uploaded file to the temporary folder
  12. def save_file(file):
  13. # Save the uploaded file with a secure filename and return the file path
  14. ct = datetime.now()
  15. ts = ct.timestamp()
  16. filename = str(ts) + "_" + secure_filename(file.filename)
  17. file_path = os.path.join(TEMP_FOLDER, filename)
  18. file.save(file_path)
  19. return file_path
  20. # Function to load and split the data from the PDF file
  21. def load_and_split_data(file_path):
  22. # Load the PDF file and split the data into chunks
  23. loader = UnstructuredPDFLoader(file_path=file_path)
  24. data = loader.load()
  25. text_splitter = RecursiveCharacterTextSplitter(chunk_size=7500, chunk_overlap=100)
  26. chunks = text_splitter.split_documents(data)
  27. return chunks
  28. # Main function to handle the embedding process
  29. def embed(file):
  30. # Check if the file is valid, save it, load and split the data, add to the database, and remove the temporary file
  31. if file.filename != '' and file and allowed_file(file.filename):
  32. file_path = save_file(file)
  33. chunks = load_and_split_data(file_path)
  34. db = get_vector_db()
  35. db.add_documents(chunks)
  36. db.persist()
  37. os.remove(file_path)
  38. return True
  39. return False