|
@@ -1,3 +1,4 @@
|
|
|
|
+import json
|
|
import logging
|
|
import logging
|
|
import time
|
|
import time
|
|
from collections.abc import Mapping
|
|
from collections.abc import Mapping
|
|
@@ -8,6 +9,7 @@ from requests.exceptions import HTTPError
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
+
|
|
class FirecrawlApp:
|
|
class FirecrawlApp:
|
|
def __init__(self, api_key: str | None = None, base_url: str | None = None):
|
|
def __init__(self, api_key: str | None = None, base_url: str | None = None):
|
|
self.api_key = api_key
|
|
self.api_key = api_key
|
|
@@ -25,14 +27,16 @@ class FirecrawlApp:
|
|
return headers
|
|
return headers
|
|
|
|
|
|
def _request(
|
|
def _request(
|
|
- self,
|
|
|
|
- method: str,
|
|
|
|
- url: str,
|
|
|
|
- data: Mapping[str, Any] | None = None,
|
|
|
|
- headers: Mapping[str, str] | None = None,
|
|
|
|
- retries: int = 3,
|
|
|
|
- backoff_factor: float = 0.3,
|
|
|
|
|
|
+ self,
|
|
|
|
+ method: str,
|
|
|
|
+ url: str,
|
|
|
|
+ data: Mapping[str, Any] | None = None,
|
|
|
|
+ headers: Mapping[str, str] | None = None,
|
|
|
|
+ retries: int = 3,
|
|
|
|
+ backoff_factor: float = 0.3,
|
|
) -> Mapping[str, Any] | None:
|
|
) -> Mapping[str, Any] | None:
|
|
|
|
+ if not headers:
|
|
|
|
+ headers = self._prepare_headers()
|
|
for i in range(retries):
|
|
for i in range(retries):
|
|
try:
|
|
try:
|
|
response = requests.request(method, url, json=data, headers=headers)
|
|
response = requests.request(method, url, json=data, headers=headers)
|
|
@@ -47,47 +51,51 @@ class FirecrawlApp:
|
|
|
|
|
|
def scrape_url(self, url: str, **kwargs):
|
|
def scrape_url(self, url: str, **kwargs):
|
|
endpoint = f'{self.base_url}/v0/scrape'
|
|
endpoint = f'{self.base_url}/v0/scrape'
|
|
- headers = self._prepare_headers()
|
|
|
|
data = {'url': url, **kwargs}
|
|
data = {'url': url, **kwargs}
|
|
- response = self._request('POST', endpoint, data, headers)
|
|
|
|
logger.debug(f"Sent request to {endpoint=} body={data}")
|
|
logger.debug(f"Sent request to {endpoint=} body={data}")
|
|
|
|
+ response = self._request('POST', endpoint, data)
|
|
if response is None:
|
|
if response is None:
|
|
raise HTTPError("Failed to scrape URL after multiple retries")
|
|
raise HTTPError("Failed to scrape URL after multiple retries")
|
|
return response
|
|
return response
|
|
|
|
|
|
def search(self, query: str, **kwargs):
|
|
def search(self, query: str, **kwargs):
|
|
endpoint = f'{self.base_url}/v0/search'
|
|
endpoint = f'{self.base_url}/v0/search'
|
|
- headers = self._prepare_headers()
|
|
|
|
data = {'query': query, **kwargs}
|
|
data = {'query': query, **kwargs}
|
|
- response = self._request('POST', endpoint, data, headers)
|
|
|
|
logger.debug(f"Sent request to {endpoint=} body={data}")
|
|
logger.debug(f"Sent request to {endpoint=} body={data}")
|
|
|
|
+ response = self._request('POST', endpoint, data)
|
|
if response is None:
|
|
if response is None:
|
|
raise HTTPError("Failed to perform search after multiple retries")
|
|
raise HTTPError("Failed to perform search after multiple retries")
|
|
return response
|
|
return response
|
|
|
|
|
|
def crawl_url(
|
|
def crawl_url(
|
|
- self, url: str, wait: bool = False, poll_interval: int = 5, idempotency_key: str | None = None, **kwargs
|
|
|
|
|
|
+ self, url: str, wait: bool = True, poll_interval: int = 5, idempotency_key: str | None = None, **kwargs
|
|
):
|
|
):
|
|
endpoint = f'{self.base_url}/v0/crawl'
|
|
endpoint = f'{self.base_url}/v0/crawl'
|
|
headers = self._prepare_headers(idempotency_key)
|
|
headers = self._prepare_headers(idempotency_key)
|
|
- data = {'url': url, **kwargs['params']}
|
|
|
|
- response = self._request('POST', endpoint, data, headers)
|
|
|
|
|
|
+ data = {'url': url, **kwargs}
|
|
logger.debug(f"Sent request to {endpoint=} body={data}")
|
|
logger.debug(f"Sent request to {endpoint=} body={data}")
|
|
|
|
+ response = self._request('POST', endpoint, data, headers)
|
|
if response is None:
|
|
if response is None:
|
|
raise HTTPError("Failed to initiate crawl after multiple retries")
|
|
raise HTTPError("Failed to initiate crawl after multiple retries")
|
|
job_id: str = response['jobId']
|
|
job_id: str = response['jobId']
|
|
if wait:
|
|
if wait:
|
|
return self._monitor_job_status(job_id=job_id, poll_interval=poll_interval)
|
|
return self._monitor_job_status(job_id=job_id, poll_interval=poll_interval)
|
|
- return job_id
|
|
|
|
|
|
+ return response
|
|
|
|
|
|
def check_crawl_status(self, job_id: str):
|
|
def check_crawl_status(self, job_id: str):
|
|
endpoint = f'{self.base_url}/v0/crawl/status/{job_id}'
|
|
endpoint = f'{self.base_url}/v0/crawl/status/{job_id}'
|
|
- headers = self._prepare_headers()
|
|
|
|
- response = self._request('GET', endpoint, headers=headers)
|
|
|
|
|
|
+ response = self._request('GET', endpoint)
|
|
if response is None:
|
|
if response is None:
|
|
raise HTTPError(f"Failed to check status for job {job_id} after multiple retries")
|
|
raise HTTPError(f"Failed to check status for job {job_id} after multiple retries")
|
|
return response
|
|
return response
|
|
|
|
|
|
|
|
+ def cancel_crawl_job(self, job_id: str):
|
|
|
|
+ endpoint = f'{self.base_url}/v0/crawl/cancel/{job_id}'
|
|
|
|
+ response = self._request('DELETE', endpoint)
|
|
|
|
+ if response is None:
|
|
|
|
+ raise HTTPError(f"Failed to cancel job {job_id} after multiple retries")
|
|
|
|
+ return response
|
|
|
|
+
|
|
def _monitor_job_status(self, job_id: str, poll_interval: int):
|
|
def _monitor_job_status(self, job_id: str, poll_interval: int):
|
|
while True:
|
|
while True:
|
|
status = self.check_crawl_status(job_id)
|
|
status = self.check_crawl_status(job_id)
|
|
@@ -96,3 +104,21 @@ class FirecrawlApp:
|
|
elif status['status'] == 'failed':
|
|
elif status['status'] == 'failed':
|
|
raise HTTPError(f'Job {job_id} failed: {status["error"]}')
|
|
raise HTTPError(f'Job {job_id} failed: {status["error"]}')
|
|
time.sleep(poll_interval)
|
|
time.sleep(poll_interval)
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+def get_array_params(tool_parameters: dict[str, Any], key):
|
|
|
|
+ param = tool_parameters.get(key)
|
|
|
|
+ if param:
|
|
|
|
+ return param.split(',')
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+def get_json_params(tool_parameters: dict[str, Any], key):
|
|
|
|
+ param = tool_parameters.get(key)
|
|
|
|
+ if param:
|
|
|
|
+ try:
|
|
|
|
+ # support both single quotes and double quotes
|
|
|
|
+ param = param.replace("'", '"')
|
|
|
|
+ param = json.loads(param)
|
|
|
|
+ except:
|
|
|
|
+ raise ValueError(f"Invalid {key} format.")
|
|
|
|
+ return param
|