chart.py 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. import matplotlib.pyplot as plt
  2. from fontTools.ttLib import TTFont
  3. from matplotlib.font_manager import findSystemFonts
  4. from core.tools.entities.values import ToolLabelEnum
  5. from core.tools.errors import ToolProviderCredentialValidationError
  6. from core.tools.provider.builtin.chart.tools.line import LinearChartTool
  7. from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
  8. # use a business theme
  9. plt.style.use('seaborn-v0_8-darkgrid')
  10. plt.rcParams['axes.unicode_minus'] = False
  11. def init_fonts():
  12. fonts = findSystemFonts()
  13. popular_unicode_fonts = [
  14. 'Arial Unicode MS', 'DejaVu Sans', 'DejaVu Sans Mono', 'DejaVu Serif', 'FreeMono', 'FreeSans', 'FreeSerif',
  15. 'Liberation Mono', 'Liberation Sans', 'Liberation Serif', 'Noto Mono', 'Noto Sans', 'Noto Serif', 'Open Sans',
  16. 'Roboto', 'Source Code Pro', 'Source Sans Pro', 'Source Serif Pro', 'Ubuntu', 'Ubuntu Mono'
  17. ]
  18. supported_fonts = []
  19. for font_path in fonts:
  20. try:
  21. font = TTFont(font_path)
  22. # get family name
  23. family_name = font['name'].getName(1, 3, 1).toUnicode()
  24. if family_name in popular_unicode_fonts:
  25. supported_fonts.append(family_name)
  26. except:
  27. pass
  28. plt.rcParams['font.family'] = 'sans-serif'
  29. # sort by order of popular_unicode_fonts
  30. for font in popular_unicode_fonts:
  31. if font in supported_fonts:
  32. plt.rcParams['font.sans-serif'] = font
  33. break
# Configure fonts once, when this module is imported. The settings live in
# matplotlib's process-global plt.rcParams, so they affect figures created
# afterwards in this process.
init_fonts()
  35. class ChartProvider(BuiltinToolProviderController):
  36. def _validate_credentials(self, credentials: dict) -> None:
  37. try:
  38. LinearChartTool().fork_tool_runtime(
  39. runtime={
  40. "credentials": credentials,
  41. }
  42. ).invoke(
  43. user_id='',
  44. tool_parameters={
  45. "data": "1,3,5,7,9,2,4,6,8,10",
  46. },
  47. )
  48. except Exception as e:
  49. raise ToolProviderCredentialValidationError(str(e))
  50. def _get_tool_labels(self) -> list[ToolLabelEnum]:
  51. return [
  52. ToolLabelEnum.DESIGN, ToolLabelEnum.PRODUCTIVITY, ToolLabelEnum.UTILITIES
  53. ]