web_reader_tool.py

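"""Web reader tool.

Fetches a URL, extracts the main article content with Mozilla's Readability.js
(via the readabilipy package, falling back to newspaper3k), and optionally
summarizes the result with a LangChain refine chain.
"""
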
import hashlib
import json
import os
import re
import site
import subprocess
import tempfile
import unicodedata
from contextlib import contextmanager
from typing import Any

import requests
from bs4 import BeautifulSoup, CData, Comment, NavigableString
from langchain.chains import RefineDocumentsChain
from langchain.chains.summarize import refine_prompts
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.tools.base import BaseTool
from newspaper import Article
from pydantic import BaseModel, Field
from regex import regex

from core.chain.llm_chain import LLMChain
from core.entities.application_entities import ModelConfigEntity
from core.rag.extractor import extract_processor
from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.models.document import Document


FULL_TEMPLATE = """
TITLE: {title}
AUTHORS: {authors}
PUBLISH DATE: {publish_date}
TOP_IMAGE_URL: {top_image}
TEXT:
{text}
"""


class WebReaderToolInput(BaseModel):
    url: str = Field(..., description="URL of the website to read")
    summary: bool = Field(
        default=False,
        description="When the user's question requires a summary of the webpage content, "
                    "set it to true."
    )
    cursor: int = Field(
        default=0,
        description="Start reading from this character. "
                    "Use this when the first response was truncated "
                    "and you want to continue reading the page. "
                    "The value cannot exceed 24000.",
    )


class WebReaderTool(BaseTool):
    """Reader tool for getting website title and contents. Gives more control than SimpleReaderTool."""

    name: str = "web_reader"
    args_schema: type[BaseModel] = WebReaderToolInput
    description: str = "Use this to read a website. " \
                       "If you can answer the question based on the information already provided, " \
                       "there is no need to use this tool."
    page_contents: str = None
    url: str = None
    max_chunk_length: int = 4000
    summary_chunk_tokens: int = 4000
    summary_chunk_overlap: int = 0
    summary_separators: list[str] = ["\n\n", "。", ".", " ", ""]
    continue_reading: bool = True
    model_config: ModelConfigEntity
    model_parameters: dict[str, Any]

    def _run(self, url: str, summary: bool = False, cursor: int = 0) -> str:
        try:
            if not self.page_contents or self.url != url:
                page_contents = get_url(url)
                self.page_contents = page_contents
                self.url = url
            else:
                page_contents = self.page_contents
        except Exception as e:
            return f'Failed to read this website, caused by: {str(e)}.'

        if summary:
            character_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
                chunk_size=self.summary_chunk_tokens,
                chunk_overlap=self.summary_chunk_overlap,
                separators=self.summary_separators
            )

            texts = character_splitter.split_text(page_contents)
            docs = [Document(page_content=t) for t in texts]

            if len(docs) == 0 or docs[0].page_content.endswith('TEXT:'):
                return "No content found."

            # only use the first 5 docs
            if len(docs) > 5:
                docs = docs[:5]

            chain = self.get_summary_chain()
            try:
                page_contents = chain.run(docs)
            except Exception as e:
                return f'Failed to read this website, caused by: {str(e)}.'
        else:
            page_contents = page_result(page_contents, cursor, self.max_chunk_length)

            if self.continue_reading and len(page_contents) >= self.max_chunk_length:
                page_contents += f"\nPAGE WAS TRUNCATED. IF YOU FIND INFORMATION THAT CAN ANSWER THE QUESTION, " \
                                 f"ANSWER DIRECTLY AND STOP INVOKING THE web_reader TOOL; OTHERWISE USE " \
                                 f"CURSOR={cursor + len(page_contents)} TO CONTINUE READING."

        return page_contents

    async def _arun(self, url: str) -> str:
        raise NotImplementedError

    def get_summary_chain(self) -> RefineDocumentsChain:
        initial_chain = LLMChain(
            model_config=self.model_config,
            prompt=refine_prompts.PROMPT,
            parameters=self.model_parameters
        )
        refine_chain = LLMChain(
            model_config=self.model_config,
            prompt=refine_prompts.REFINE_PROMPT,
            parameters=self.model_parameters
        )
        return RefineDocumentsChain(
            initial_llm_chain=initial_chain,
            refine_llm_chain=refine_chain,
            document_variable_name="text",
            initial_response_name="existing_answer",
            callbacks=self.callbacks
        )


def page_result(text: str, cursor: int, max_length: int) -> str:
    """Page through `text` and return a substring of `max_length` characters starting from `cursor`."""
    return text[cursor: cursor + max_length]


def get_url(url: str, user_agent: str = None) -> str:
    """Fetch URL and return the contents as a string."""
    headers = {
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
    }
    if user_agent:
        headers["User-Agent"] = user_agent

    supported_content_types = extract_processor.SUPPORT_URL_CONTENT_TYPES + ["text/html"]

    head_response = requests.head(url, headers=headers, allow_redirects=True, timeout=(5, 10))

    if head_response.status_code != 200:
        return "URL returned status code {}.".format(head_response.status_code)

    # check content-type (default to empty string if the header is missing)
    main_content_type = head_response.headers.get('Content-Type', '').split(';')[0].strip()
    if main_content_type not in supported_content_types:
        return "Unsupported content-type [{}] of URL.".format(main_content_type)

    if main_content_type in extract_processor.SUPPORT_URL_CONTENT_TYPES:
        return ExtractProcessor.load_from_url(url, return_text=True)

    response = requests.get(url, headers=headers, allow_redirects=True, timeout=(5, 30))
    a = extract_using_readabilipy(response.text)

    if not a['plain_text'] or not a['plain_text'].strip():
        return get_url_from_newspaper3k(url)

    res = FULL_TEMPLATE.format(
        title=a['title'],
        authors=a['byline'],
        publish_date=a['date'],
        top_image="",
        text=a['plain_text'] if a['plain_text'] else "",
    )

    return res


def get_url_from_newspaper3k(url: str) -> str:
    """Extract the article with newspaper3k and render it with FULL_TEMPLATE."""
    a = Article(url)
    a.download()
    a.parse()

    res = FULL_TEMPLATE.format(
        title=a.title,
        authors=a.authors,
        publish_date=a.publish_date,
        top_image=a.top_image,
        text=a.text,
    )

    return res
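

# The helpers below run Mozilla's Readability.js via node (using the JavaScript
# bundled with the readabilipy package), convert its output into plain text,
# normalise unicode and whitespace, and can optionally annotate nodes with
# node indexes and content digests.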
def extract_using_readabilipy(html):
    with tempfile.NamedTemporaryFile(delete=False, mode='w+') as f_html:
        f_html.write(html)
        f_html.close()
    html_path = f_html.name

    # Call Mozilla's Readability.js Readability.parse() function via node, writing output to a temporary file
    article_json_path = html_path + ".json"
    jsdir = os.path.join(find_module_path('readabilipy'), 'javascript')
    with chdir(jsdir):
        subprocess.check_call(["node", "ExtractArticle.js", "-i", html_path, "-o", article_json_path])

    # Read output of call to Readability.parse() from JSON file and return as Python dictionary
    with open(article_json_path, encoding="utf-8") as json_file:
        input_json = json.loads(json_file.read())

    # Delete the temporary files after processing
    os.unlink(article_json_path)
    os.unlink(html_path)

    article_json = {
        "title": None,
        "byline": None,
        "date": None,
        "content": None,
        "plain_content": None,
        "plain_text": None
    }
    # Populate article fields from readability fields where present
    if input_json:
        if "title" in input_json and input_json["title"]:
            article_json["title"] = input_json["title"]
        if "byline" in input_json and input_json["byline"]:
            article_json["byline"] = input_json["byline"]
        if "date" in input_json and input_json["date"]:
            article_json["date"] = input_json["date"]
        if "content" in input_json and input_json["content"]:
            article_json["content"] = input_json["content"]
            article_json["plain_content"] = plain_content(article_json["content"], False, False)
            article_json["plain_text"] = extract_text_blocks_as_plain_text(article_json["plain_content"])
        if "textContent" in input_json and input_json["textContent"]:
            article_json["plain_text"] = input_json["textContent"]
            article_json["plain_text"] = re.sub(r'\n\s*\n', '\n', article_json["plain_text"])

    return article_json


def find_module_path(module_name):
    for package_path in site.getsitepackages():
        potential_path = os.path.join(package_path, module_name)
        if os.path.exists(potential_path):
            return potential_path

    return None


@contextmanager
def chdir(path):
    """Change directory in context and return to original on exit"""
    # From https://stackoverflow.com/a/37996581, couldn't find a built-in
    original_path = os.getcwd()
    os.chdir(path)
    try:
        yield
    finally:
        os.chdir(original_path)


def extract_text_blocks_as_plain_text(paragraph_html):
    # Load article as DOM
    soup = BeautifulSoup(paragraph_html, 'html.parser')
    # Select all lists
    list_elements = soup.find_all(['ul', 'ol'])
    # Prefix text in all list items with "* " and make lists paragraphs
    for list_element in list_elements:
        plain_items = "".join(list(filter(None, [plain_text_leaf_node(li)["text"] for li in list_element.find_all('li')])))
        list_element.string = plain_items
        list_element.name = "p"
    # Select all text blocks
    text_blocks = [s.parent for s in soup.find_all(string=True)]
    text_blocks = [plain_text_leaf_node(block) for block in text_blocks]
    # Drop empty paragraphs
    text_blocks = list(filter(lambda p: p["text"] is not None, text_blocks))
    return text_blocks


def plain_text_leaf_node(element):
    # Extract all text, stripped of any child HTML elements and normalise it
    plain_text = normalise_text(element.get_text())
    if plain_text != "" and element.name == "li":
        plain_text = "* {}, ".format(plain_text)
    if plain_text == "":
        plain_text = None
    if "data-node-index" in element.attrs:
        plain = {"node_index": element["data-node-index"], "text": plain_text}
    else:
        plain = {"text": plain_text}
    return plain


def plain_content(readability_content, content_digests, node_indexes):
    # Load article as DOM
    soup = BeautifulSoup(readability_content, 'html.parser')
    # Make all elements plain
    elements = plain_elements(soup.contents, content_digests, node_indexes)
    if node_indexes:
        # Add node index attributes to nodes
        elements = [add_node_indexes(element) for element in elements]
    # Replace article contents with plain elements
    soup.contents = elements
    return str(soup)


def plain_elements(elements, content_digests, node_indexes):
    # Get plain content versions of all elements
    elements = [plain_element(element, content_digests, node_indexes)
                for element in elements]
    if content_digests:
        # Add content digest attribute to nodes
        elements = [add_content_digest(element) for element in elements]
    return elements


def plain_element(element, content_digests, node_indexes):
    # For lists, we make each item plain text
    if is_leaf(element):
        # For leaf node elements, extract the text content, discarding any HTML tags
        # 1. Get element contents as text
        plain_text = element.get_text()
        # 2. Normalise the extracted text string to a canonical representation
        plain_text = normalise_text(plain_text)
        # 3. Update element content to be plain text
        element.string = plain_text
    elif is_text(element):
        if is_non_printing(element):
            # The simplified HTML may have come from Readability.js so might
            # have non-printing text (e.g. Comment or CData). In this case, we
            # keep the structure, but ensure that the string is empty.
            element = type(element)("")
        else:
            plain_text = element.string
            plain_text = normalise_text(plain_text)
            element = type(element)(plain_text)
    else:
        # If not a leaf node or leaf type, recurse into child nodes and replace their contents
        element.contents = plain_elements(element.contents, content_digests, node_indexes)
    return element


def add_node_indexes(element, node_index="0"):
    # Can't add attributes to string types
    if is_text(element):
        return element
    # Add index to current element
    element["data-node-index"] = node_index
    # Add index to child elements
    for local_idx, child in enumerate(
            [c for c in element.contents if not is_text(c)], start=1):
        # Can't add attributes to leaf string types
        child_index = "{stem}.{local}".format(
            stem=node_index, local=local_idx)
        add_node_indexes(child, node_index=child_index)
    return element


def normalise_text(text):
    """Normalise unicode and whitespace."""
    # Normalise unicode first to try and standardise whitespace characters as much as possible before normalising them
    text = strip_control_characters(text)
    text = normalise_unicode(text)
    text = normalise_whitespace(text)
    return text


def strip_control_characters(text):
    """Strip out unicode control characters which might break the parsing."""
    # Unicode control characters
    #   [Cc]: Other, Control [includes new lines]
    #   [Cf]: Other, Format
    #   [Cn]: Other, Not Assigned
    #   [Co]: Other, Private Use
    #   [Cs]: Other, Surrogate
    control_chars = set(['Cc', 'Cf', 'Cn', 'Co', 'Cs'])
    retained_chars = ['\t', '\n', '\r', '\f']
    # Remove non-printing control characters
    return "".join(["" if (unicodedata.category(char) in control_chars) and (char not in retained_chars) else char for char in text])


def normalise_unicode(text):
    """Normalise unicode such that things that are visually equivalent map to the same unicode string where possible."""
    normal_form = "NFKC"
    text = unicodedata.normalize(normal_form, text)
    return text


def normalise_whitespace(text):
    """Replace runs of whitespace characters with a single space as this is what happens when HTML text is displayed."""
    text = regex.sub(r"\s+", " ", text)
    # Remove leading and trailing whitespace
    text = text.strip()
    return text


def is_leaf(element):
    return (element.name in ['p', 'li'])


def is_text(element):
    return isinstance(element, NavigableString)


def is_non_printing(element):
    return any(isinstance(element, _e) for _e in [Comment, CData])


def add_content_digest(element):
    if not is_text(element):
        element["data-content-digest"] = content_digest(element)
    return element


def content_digest(element):
    if is_text(element):
        # Hash
        trimmed_string = element.string.strip()
        if trimmed_string == "":
            digest = ""
        else:
            digest = hashlib.sha256(trimmed_string.encode('utf-8')).hexdigest()
    else:
        contents = element.contents
        num_contents = len(contents)
        if num_contents == 0:
            # No hash when no child elements exist
            digest = ""
        elif num_contents == 1:
            # If single child, use digest of child
            digest = content_digest(contents[0])
        else:
            # Build content digest from the "non-empty" digests of child nodes
            digest = hashlib.sha256()
            child_digests = list(
                filter(lambda x: x != "", [content_digest(content) for content in contents]))
            for child in child_digests:
                digest.update(child.encode('utf-8'))
            digest = digest.hexdigest()
    return digest
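

# A minimal usage sketch (assuming the surrounding application supplies a valid
# ModelConfigEntity; `model_config` and the parameter values below are placeholders):
#
#     tool = WebReaderTool(model_config=model_config, model_parameters={"temperature": 0.0})
#     print(tool.run({"url": "https://example.com", "summary": True}))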