@@ -1,11 +1,15 @@
 import array
 import json
+import re
 import uuid
 from contextlib import contextmanager
 from typing import Any
 
+import jieba.posseg as pseg
+import nltk
 import numpy
 import oracledb
+from nltk.corpus import stopwords
 from pydantic import BaseModel, model_validator
 
 from configs import dify_config
@@ -50,6 +54,11 @@ CREATE TABLE IF NOT EXISTS {table_name} (
     ,embedding vector NOT NULL
 )
 """
+SQL_CREATE_INDEX = """
+CREATE INDEX idx_docs_{table_name} ON {table_name}(text)
+INDEXTYPE IS CTXSYS.CONTEXT PARAMETERS
+('FILTER CTXSYS.NULL_FILTER SECTION GROUP CTXSYS.HTML_SECTION_GROUP LEXER sys.my_chinese_vgram_lexer')
+"""
 
 
 class OracleVector(BaseVector):
@@ -188,7 +197,53 @@ class OracleVector(BaseVector):
         return docs
 
     def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
-        # do not support bm25 search
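+        """Full-text search over the Oracle Text index.
+
+        Key terms are extracted from the query (jieba for Chinese, NLTK
+        otherwise) and ranked with a CONTAINS ... ACCUM query.
+        """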
+        top_k = kwargs.get("top_k", 5)
+        # score_threshold is accepted but not applied yet; score-based filtering may come later.
+        score_threshold = kwargs.get("score_threshold") or 0.0
+        if len(query) > 0:
+            # If the query contains Chinese characters, segment it with jieba;
+            # otherwise tokenize it with NLTK.
+            zh_pattern = re.compile('[\u4e00-\u9fa5]+')
+            match = zh_pattern.search(query)
+            entities = []
+            if match:
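+                # jieba.posseg.cut yields (word, POS tag) pairs; consecutive
+                # content-bearing words are merged into one phrase-like entity.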
+                words = pseg.cut(query)
+                current_entity = ""
+                for word, pos in words:
+                    # Keep nouns, proper nouns, verbs, and English tokens
+                    # (nr: person name, nz: other proper noun, n: noun, v: verb, eng: English).
+                    if pos in ("nr", "Ng", "eng", "nz", "n", "ORG", "v"):
+                        current_entity += word
+                    else:
+                        if current_entity:
+                            entities.append(current_entity)
+                            current_entity = ""
+                if current_entity:
+                    entities.append(current_entity)
+            else:
+                try:
+                    nltk.data.find('tokenizers/punkt')
+                    nltk.data.find('corpora/stopwords')
+                except LookupError:
+                    # First run: download the tokenizer model and stopword list.
+                    nltk.download('punkt')
+                    nltk.download('stopwords')
+                # Strip punctuation, tokenize, and drop English stopwords.
+                e_str = re.sub(r'[^\w ]', '', query)
+                all_tokens = nltk.word_tokenize(e_str)
+                stop_words = stopwords.words('english')
+                for token in all_tokens:
+                    if token not in stop_words:
+                        entities.append(token)
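+            # CONTAINS with the ACCUM operator accumulates the scores of the
+            # individual terms; SCORE(1) refers back to the label 1 in CONTAINS.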
+            with self._get_cursor() as cur:
+                cur.execute(
+                    f"SELECT meta, text FROM {self.table_name} WHERE CONTAINS(text, :1, 1) > 0"
+                    f" ORDER BY SCORE(1) DESC FETCH FIRST {top_k} ROWS ONLY",
+                    [" ACCUM ".join(entities)],
+                )
+                docs = []
+                for record in cur:
+                    metadata, text = record
+                    docs.append(Document(page_content=text, metadata=metadata))
+            return docs
+        # An empty query matches nothing.
         return []
 
     def delete(self) -> None:
@@ -206,6 +261,8 @@ class OracleVector(BaseVector):
             with self._get_cursor() as cur:
                 cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name))
             redis_client.set(collection_exist_cache_key, 1, ex=3600)
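+            # Build the Oracle Text index used by search_by_full_text. Note that a
+            # CONTEXT index is not maintained transactionally: without a SYNC
+            # setting, rows inserted later are not searchable until the index is
+            # synced (e.g. via ctx_ddl.sync_index).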
+            with self._get_cursor() as cur:
+                cur.execute(SQL_CREATE_INDEX.format(table_name=self.table_name))
 
 
 class OracleVectorFactory(AbstractVectorFactory):