connector.py 32 KB


  1. """
  2. /***************************************************************************
  3. Name : DB Manager
  4. Description : Database manager plugin for QGIS
  5. Date : Oct 14 2016
  6. copyright : (C) 2016 by Even Rouault
  7. (C) 2011 by Giuseppe Sucameli
  8. email : even.rouault at spatialys.com
  9. ***************************************************************************/
  10. /***************************************************************************
  11. * *
  12. * This program is free software; you can redistribute it and/or modify *
  13. * it under the terms of the GNU General Public License as published by *
  14. * the Free Software Foundation; either version 2 of the License, or *
  15. * (at your option) any later version. *
  16. * *
  17. ***************************************************************************/
  18. """
  19. from functools import cmp_to_key
  20. from qgis.PyQt.QtWidgets import QApplication
  21. from qgis.PyQt.QtCore import QThread
  22. from ..connector import DBConnector
  23. from ..plugin import ConnectionError, DbError, Table
  24. from qgis.utils import spatialite_connect
  25. from qgis.core import (
  26. QgsApplication,
  27. QgsProviderRegistry,
  28. QgsAbstractDatabaseProviderConnection,
  29. QgsProviderConnectionException,
  30. QgsWkbTypes,
  31. )
  32. import sqlite3
  33. from osgeo import gdal, ogr, osr
  34. def classFactory():
  35. return GPKGDBConnector
  36. class GPKGDBConnector(DBConnector):
  37. def __init__(self, uri, connection):
  38. """Creates a new GPKG connector
  39. :param uri: data source URI
  40. :type uri: QgsDataSourceUri
  41. :param connection: the GPKGDBPlugin parent instance
  42. :type connection: GPKGDBPlugin
  43. """
  44. DBConnector.__init__(self, uri)
  45. self.dbname = uri.database()
  46. self.connection = connection
  47. self._current_thread = None
  48. md = QgsProviderRegistry.instance().providerMetadata(connection.providerName())
  49. # QgsAbstractDatabaseProviderConnection instance
  50. self.core_connection = md.findConnection(connection.connectionName())
  51. if self.core_connection is None:
  52. self.core_connection = md.createConnection(uri.uri(), {})
  53. self.has_raster = False
  54. self.mapSridToName = {}
  55. # To be removed when migration to new API is completed
  56. self._opendb()
  57. def _opendb(self):
  58. # Keep this explicit assignment to None to make sure the file is
  59. # properly closed before being re-opened
  60. self.gdal_ds = None
  61. self.gdal_ds = gdal.OpenEx(self.dbname, gdal.OF_UPDATE)
  62. if self.gdal_ds is None:
  63. self.gdal_ds = gdal.OpenEx(self.dbname)
  64. if self.gdal_ds is None:
  65. raise ConnectionError(QApplication.translate("DBManagerPlugin", '"{0}" not found').format(self.dbname))
  66. if self.gdal_ds.GetDriver().ShortName != 'GPKG':
  67. raise ConnectionError(QApplication.translate("DBManagerPlugin", '"{dbname}" not recognized as GPKG ({shortname} reported instead.)').format(dbname=self.dbname, shortname=self.gdal_ds.GetDriver().ShortName))
  68. self.has_raster = self.gdal_ds.RasterCount != 0 or self.gdal_ds.GetMetadata('SUBDATASETS') is not None
  69. self.connection = None
  70. self._current_thread = None
  71. @property
  72. def connection(self):
  73. """Creates and returns a spatialite connection, if
  74. the existing connection was created in another thread
  75. invalidates it and create a new one.
  76. """
  77. if self._connection is None or self._current_thread != int(QThread.currentThreadId()):
  78. self._current_thread = int(QThread.currentThreadId())
  79. try:
  80. self._connection = spatialite_connect(str(self.dbname))
  81. except self.connection_error_types() as e:
  82. raise ConnectionError(e)
  83. return self._connection
    @connection.setter
    def connection(self, conn):
        # Stores conn as the cached connection object. Assigning None makes
        # the property getter open a fresh connection on next access.
        self._connection = conn
  87. def unquoteId(self, quotedId):
  88. if len(quotedId) <= 2 or quotedId[0] != '"' or quotedId[len(quotedId) - 1] != '"':
  89. return quotedId
  90. unquoted = ''
  91. i = 1
  92. while i < len(quotedId) - 1:
  93. if quotedId[i] == '"' and quotedId[i + 1] == '"':
  94. unquoted += '"'
  95. i += 2
  96. else:
  97. unquoted += quotedId[i]
  98. i += 1
  99. return unquoted
  100. def _fetchOne(self, sql):
  101. return self.core_connection.executeSql(sql)
  102. def _fetchAll(self, sql, include_fid_and_geometry=False):
  103. return self.core_connection.executeSql(sql)
  104. def _fetchAllFromLayer(self, table):
  105. lyr = self.gdal_ds.GetLayerByName(table.name)
  106. if lyr is None:
  107. return []
  108. lyr.ResetReading()
  109. ret = []
  110. while True:
  111. f = lyr.GetNextFeature()
  112. if f is None:
  113. break
  114. else:
  115. field_vals = [f.GetFID()]
  116. if lyr.GetLayerDefn().GetGeomType() != ogr.wkbNone:
  117. geom = f.GetGeometryRef()
  118. if geom is not None:
  119. geom = geom.ExportToWkt()
  120. field_vals += [geom]
  121. field_vals += [f.GetField(i) for i in range(f.GetFieldCount())]
  122. ret.append(field_vals)
  123. return ret
  124. def _execute_and_commit(self, sql):
  125. sql_lyr = self.gdal_ds.ExecuteSQL(sql)
  126. self.gdal_ds.ReleaseResultSet(sql_lyr)
  127. def _execute(self, cursor, sql):
  128. if self.connection is None:
  129. # Needed when evaluating a SQL query
  130. try:
  131. self.connection = spatialite_connect(str(self.dbname))
  132. except self.connection_error_types() as e:
  133. raise ConnectionError(e)
  134. return DBConnector._execute(self, cursor, sql)
  135. def _commit(self):
  136. if self.connection is None:
  137. return
  138. try:
  139. self.connection.commit()
  140. except self.connection_error_types() as e:
  141. raise ConnectionError(e)
  142. except self.execution_error_types() as e:
  143. # do the rollback to avoid a "current transaction aborted, commands ignored" errors
  144. self._rollback()
  145. raise DbError(e)
  146. def cancel(self):
  147. if self.connection:
  148. self.connection.interrupt()
  149. @classmethod
  150. def isValidDatabase(cls, path):
  151. if hasattr(gdal, 'OpenEx'):
  152. ds = gdal.OpenEx(path)
  153. if ds is None or ds.GetDriver().ShortName != 'GPKG':
  154. return False
  155. else:
  156. ds = ogr.Open(path)
  157. if ds is None or ds.GetDriver().GetName() != 'GPKG':
  158. return False
  159. return True
  160. def getInfo(self):
  161. return None
  162. def getSpatialInfo(self):
  163. return None
  164. def hasSpatialSupport(self):
  165. return True
  166. # Used by DlgTableProperties
  167. def canAddGeometryColumn(self, table):
  168. _, tablename = self.getSchemaTableName(table)
  169. lyr = self.gdal_ds.GetLayerByName(tablename)
  170. if lyr is None:
  171. return False
  172. return lyr.GetGeomType() == ogr.wkbNone
  173. # Used by DlgTableProperties
  174. def canAddSpatialIndex(self, table):
  175. _, tablename = self.getSchemaTableName(table)
  176. lyr = self.gdal_ds.GetLayerByName(tablename)
  177. if lyr is None or lyr.GetGeometryColumn() == '':
  178. return False
  179. return not self.hasSpatialIndex(table,
  180. lyr.GetGeometryColumn())
  181. def hasRasterSupport(self):
  182. return self.has_raster
  183. def hasCustomQuerySupport(self):
  184. return True
  185. def hasTableColumnEditingSupport(self):
  186. return True
  187. def hasCreateSpatialViewSupport(self):
  188. return False
  189. def fieldTypes(self):
  190. # From "Table 1. GeoPackage Data Types" (http://www.geopackage.org/spec/)
  191. return [
  192. "TEXT",
  193. "MEDIUMINT",
  194. "INTEGER",
  195. "TINYINT",
  196. "SMALLINT",
  197. "DOUBLE",
  198. "FLOAT"
  199. "DATE",
  200. "DATETIME",
  201. "BOOLEAN",
  202. ]
  203. def getSchemas(self):
  204. return None
  205. def getTables(self, schema=None, add_sys_tables=False):
  206. """ get list of tables """
  207. items = []
  208. try:
  209. vectors = self.getVectorTables(schema)
  210. for tbl in vectors:
  211. items.append(tbl)
  212. except DbError:
  213. pass
  214. try:
  215. rasters = self.getRasterTables(schema)
  216. for tbl in rasters:
  217. items.append(tbl)
  218. except DbError:
  219. pass
  220. for i, tbl in enumerate(items):
  221. tbl.insert(3, False) # not system table
  222. return sorted(items, key=cmp_to_key(lambda x, y: (x[1] > y[1]) - (x[1] < y[1])))
  223. def getVectorTables(self, schema=None):
  224. """Returns a list of vector table information
  225. """
  226. items = []
  227. for table in self.core_connection.tables(schema, QgsAbstractDatabaseProviderConnection.Vector | QgsAbstractDatabaseProviderConnection.Aspatial):
  228. if not (table.flags() & QgsAbstractDatabaseProviderConnection.Aspatial):
  229. geom_type = table.geometryColumnTypes()[0]
  230. # Use integer PG code for SRID
  231. srid = geom_type.crs.postgisSrid()
  232. geomtype_flatten = QgsWkbTypes.flatType(geom_type.wkbType)
  233. geomname = 'GEOMETRY'
  234. if geomtype_flatten == QgsWkbTypes.Point:
  235. geomname = 'POINT'
  236. elif geomtype_flatten == QgsWkbTypes.LineString:
  237. geomname = 'LINESTRING'
  238. elif geomtype_flatten == QgsWkbTypes.Polygon:
  239. geomname = 'POLYGON'
  240. elif geomtype_flatten == QgsWkbTypes.MultiPoint:
  241. geomname = 'MULTIPOINT'
  242. elif geomtype_flatten == QgsWkbTypes.MultiLineString:
  243. geomname = 'MULTILINESTRING'
  244. elif geomtype_flatten == QgsWkbTypes.MultiPolygon:
  245. geomname = 'MULTIPOLYGON'
  246. elif geomtype_flatten == QgsWkbTypes.GeometryCollection:
  247. geomname = 'GEOMETRYCOLLECTION'
  248. elif geomtype_flatten == QgsWkbTypes.CircularString:
  249. geomname = 'CIRCULARSTRING'
  250. elif geomtype_flatten == QgsWkbTypes.CompoundCurve:
  251. geomname = 'COMPOUNDCURVE'
  252. elif geomtype_flatten == QgsWkbTypes.CurvePolygon:
  253. geomname = 'CURVEPOLYGON'
  254. elif geomtype_flatten == QgsWkbTypes.MultiCurve:
  255. geomname = 'MULTICURVE'
  256. elif geomtype_flatten == QgsWkbTypes.MultiSurface:
  257. geomname = 'MULTISURFACE'
  258. geomdim = 'XY'
  259. if QgsWkbTypes.hasZ(geom_type.wkbType):
  260. geomdim += 'Z'
  261. if QgsWkbTypes.hasM(geom_type.wkbType):
  262. geomdim += 'M'
  263. item = [
  264. Table.VectorType,
  265. table.tableName(),
  266. bool(table.flags() & QgsAbstractDatabaseProviderConnection.View), # is_view
  267. table.tableName(),
  268. table.geometryColumn(),
  269. geomname,
  270. geomdim,
  271. srid
  272. ]
  273. self.mapSridToName[srid] = geom_type.crs.description()
  274. else:
  275. item = [
  276. Table.TableType,
  277. table.tableName(),
  278. bool(table.flags() & QgsAbstractDatabaseProviderConnection.View),
  279. ]
  280. items.append(item)
  281. return items
  282. def getRasterTables(self, schema=None):
  283. """ get list of table with a geometry column
  284. it returns:
  285. name (table name)
  286. type = 'view' (is a view?)
  287. geometry_column:
  288. r.table_name (the prefix table name, use this to load the layer)
  289. r.geometry_column
  290. srid
  291. """
  292. items = []
  293. for table in self.core_connection.tables(schema, QgsAbstractDatabaseProviderConnection.Raster):
  294. geom_type = table.geometryColumnTypes()[0]
  295. # Use integer PG code for SRID
  296. srid = geom_type.crs.postgisSrid()
  297. item = [
  298. Table.RasterType,
  299. table.tableName(),
  300. bool(table.flags() & QgsAbstractDatabaseProviderConnection.View),
  301. table.tableName(),
  302. table.geometryColumn(),
  303. srid,
  304. ]
  305. self.mapSridToName[srid] = geom_type.crs.description()
  306. items.append(item)
  307. return items
  308. def getTableRowCount(self, table):
  309. lyr = self.gdal_ds.GetLayerByName(self.getSchemaTableName(table)[1])
  310. return lyr.GetFeatureCount() if lyr is not None else None
  311. def getTableFields(self, table):
  312. """ return list of columns in table """
  313. sql = "PRAGMA table_info(%s)" % (self.quoteId(table))
  314. ret = self._fetchAll(sql)
  315. if ret is None:
  316. ret = []
  317. return ret
  318. def getTableIndexes(self, table):
  319. """ get info about table's indexes """
  320. sql = "PRAGMA index_list(%s)" % (self.quoteId(table))
  321. indexes = self._fetchAll(sql)
  322. if indexes is None:
  323. return []
  324. for i, idx in enumerate(indexes):
  325. # sqlite has changed the number of columns returned by index_list since 3.8.9
  326. # I am not using self.getInfo() here because this behavior
  327. # can be changed back without notice as done for index_info, see:
  328. # http://repo.or.cz/sqlite.git/commit/53555d6da78e52a430b1884b5971fef33e9ccca4
  329. if len(idx) == 3:
  330. num, name, unique = idx
  331. if len(idx) == 5:
  332. num, name, unique, createdby, partial = idx
  333. sql = "PRAGMA index_info(%s)" % (self.quoteId(name))
  334. idx = [num, name, unique]
  335. cols = [
  336. cid
  337. for seq, cid, cname in self._fetchAll(sql)
  338. ]
  339. idx.append(cols)
  340. indexes[i] = idx
  341. return indexes
  342. def getTableConstraints(self, table):
  343. return None
  344. def getTableTriggers(self, table):
  345. _, tablename = self.getSchemaTableName(table)
  346. # Do not list rtree related triggers as we don't want them to be dropped
  347. sql = "SELECT name, sql FROM sqlite_master WHERE tbl_name = %s AND type = 'trigger'" % (self.quoteString(tablename))
  348. if self.isVectorTable(table):
  349. sql += " AND name NOT LIKE 'rtree_%%'"
  350. elif self.isRasterTable(table):
  351. sql += " AND name NOT LIKE '%%_zoom_insert'"
  352. sql += " AND name NOT LIKE '%%_zoom_update'"
  353. sql += " AND name NOT LIKE '%%_tile_column_insert'"
  354. sql += " AND name NOT LIKE '%%_tile_column_update'"
  355. sql += " AND name NOT LIKE '%%_tile_row_insert'"
  356. sql += " AND name NOT LIKE '%%_tile_row_update'"
  357. return self._fetchAll(sql)
  358. def deleteTableTrigger(self, trigger, table=None):
  359. """Deletes trigger """
  360. sql = "DROP TRIGGER %s" % self.quoteId(trigger)
  361. self._execute_and_commit(sql)
  362. def getTableExtent(self, table, geom, force=False):
  363. """ find out table extent """
  364. _, tablename = self.getSchemaTableName(table)
  365. if self.isRasterTable(table):
  366. md = self.gdal_ds.GetMetadata('SUBDATASETS')
  367. if md is None or len(md) == 0:
  368. ds = self.gdal_ds
  369. else:
  370. subdataset_name = 'GPKG:%s:%s' % (self.gdal_ds.GetDescription(), tablename)
  371. ds = gdal.Open(subdataset_name)
  372. if ds is None:
  373. return None
  374. gt = ds.GetGeoTransform()
  375. minx = gt[0]
  376. maxx = gt[0] + gt[1] * ds.RasterYSize
  377. maxy = gt[3]
  378. miny = gt[3] + gt[5] * ds.RasterYSize
  379. return (minx, miny, maxx, maxy)
  380. lyr = self.gdal_ds.GetLayerByName(tablename)
  381. if lyr is None:
  382. return None
  383. ret = lyr.GetExtent(force=force, can_return_null=True)
  384. if ret is None:
  385. return None
  386. minx, maxx, miny, maxy = ret
  387. return (minx, miny, maxx, maxy)
  388. def getViewDefinition(self, view):
  389. """ returns definition of the view """
  390. return None
  391. def getSpatialRefInfo(self, srid):
  392. if srid in self.mapSridToName:
  393. return self.mapSridToName[srid]
  394. sql = "SELECT srs_name FROM gpkg_spatial_ref_sys WHERE srs_id = %s" % self.quoteString(srid)
  395. res = self._fetchOne(sql)
  396. if res is not None and len(res) > 0:
  397. res = res[0]
  398. self.mapSridToName[srid] = res
  399. return res
  400. def isVectorTable(self, table):
  401. _, tablename = self.getSchemaTableName(table)
  402. return self.gdal_ds.GetLayerByName(tablename) is not None
  403. def isRasterTable(self, table):
  404. if self.has_raster and not self.isVectorTable(table):
  405. _, tablename = self.getSchemaTableName(table)
  406. md = self.gdal_ds.GetMetadata('SUBDATASETS')
  407. if md is None or len(md) == 0:
  408. sql = "SELECT COUNT(*) FROM gpkg_contents WHERE data_type = 'tiles' AND table_name = %s" % self.quoteString(tablename)
  409. ret = self._fetchOne(sql)
  410. return ret != [] and ret[0][0] == 1
  411. else:
  412. subdataset_name = 'GPKG:%s:%s' % (self.gdal_ds.GetDescription(), tablename)
  413. for key in md:
  414. if md[key] == subdataset_name:
  415. return True
  416. return False
  417. def getOGRFieldTypeFromSQL(self, sql_type):
  418. ogr_type = ogr.OFTString
  419. ogr_subtype = ogr.OFSTNone
  420. width = 0
  421. if not sql_type.startswith('TEXT ('):
  422. pos = sql_type.find(' (')
  423. if pos >= 0:
  424. sql_type = sql_type[0:pos]
  425. if sql_type == 'BOOLEAN':
  426. ogr_type = ogr.OFTInteger
  427. ogr_subtype = ogr.OFSTBoolean
  428. elif sql_type in ('TINYINT', 'SMALLINT', 'MEDIUMINT'):
  429. ogr_type = ogr.OFTInteger
  430. elif sql_type == 'INTEGER':
  431. ogr_type = ogr.OFTInteger64
  432. elif sql_type == 'FLOAT':
  433. ogr_type = ogr.OFTReal
  434. ogr_subtype = ogr.OFSTFloat32
  435. elif sql_type == 'DOUBLE':
  436. ogr_type = ogr.OFTReal
  437. elif sql_type == 'DATE':
  438. ogr_type = ogr.OFTDate
  439. elif sql_type == 'DATETIME':
  440. ogr_type = ogr.OFTDateTime
  441. elif sql_type.startswith('TEXT (') and sql_type.endswith(')'):
  442. width = int(sql_type[len('TEXT ('):-1])
  443. return (ogr_type, ogr_subtype, width)
  444. def createOGRFieldDefnFromSQL(self, sql_fielddef):
  445. f_split = sql_fielddef.split(' ')
  446. quoted_name = f_split[0]
  447. name = self.unquoteId(quoted_name)
  448. sql_type = f_split[1].upper()
  449. if len(f_split) >= 3 and f_split[2].startswith('(') and f_split[2].endswith(')'):
  450. sql_type += ' ' + f_split[2]
  451. f_split = [f for f in f_split[3:]]
  452. else:
  453. f_split = [f for f in f_split[2:]]
  454. ogr_type, ogr_subtype, width = self.getOGRFieldTypeFromSQL(sql_type)
  455. fld_defn = ogr.FieldDefn(name, ogr_type)
  456. fld_defn.SetSubType(ogr_subtype)
  457. fld_defn.SetWidth(width)
  458. if len(f_split) >= 2 and f_split[0] == 'NOT' and f_split[1] == 'NULL':
  459. fld_defn.SetNullable(False)
  460. f_split = [f for f in f_split[2:]]
  461. elif len(f_split) >= 1:
  462. f_split = [f for f in f_split[1:]]
  463. if len(f_split) >= 2 and f_split[0] == 'DEFAULT':
  464. new_default = f_split[1]
  465. if new_default == '':
  466. fld_defn.SetDefault(None)
  467. elif new_default == 'NULL' or ogr_type in (ogr.OFTInteger, ogr.OFTReal):
  468. fld_defn.SetDefault(new_default)
  469. elif new_default.startswith("'") and new_default.endswith("'"):
  470. fld_defn.SetDefault(new_default)
  471. else:
  472. fld_defn.SetDefault(self.quoteString(new_default))
  473. return fld_defn
  474. def createTable(self, table, field_defs, pkey):
  475. """Creates ordinary table
  476. 'fields' is array containing field definitions
  477. 'pkey' is the primary key name
  478. """
  479. if len(field_defs) == 0:
  480. return False
  481. options = []
  482. if pkey is not None and pkey != "":
  483. options += ['FID=' + pkey]
  484. _, tablename = self.getSchemaTableName(table)
  485. lyr = self.gdal_ds.CreateLayer(tablename, geom_type=ogr.wkbNone, options=options)
  486. if lyr is None:
  487. return False
  488. for field_def in field_defs:
  489. fld_defn = self.createOGRFieldDefnFromSQL(field_def)
  490. if fld_defn.GetName() == pkey:
  491. continue
  492. if lyr.CreateField(fld_defn) != 0:
  493. return False
  494. return True
  495. def deleteTable(self, table):
  496. """Deletes table from the database """
  497. if self.isRasterTable(table):
  498. return False
  499. _, tablename = self.getSchemaTableName(table)
  500. for i in range(self.gdal_ds.GetLayerCount()):
  501. if self.gdal_ds.GetLayer(i).GetName() == tablename:
  502. return self.gdal_ds.DeleteLayer(i) == 0
  503. return False
  504. def emptyTable(self, table):
  505. """Deletes all rows from table """
  506. if self.isRasterTable(table):
  507. return False
  508. sql = "DELETE FROM %s" % self.quoteId(table)
  509. self._execute_and_commit(sql)
  510. def renameTable(self, table, new_table):
  511. """Renames the table
  512. :param table: tuple with schema and table names
  513. :type table: tuple (str, str)
  514. :param new_table: new table name
  515. :type new_table: str
  516. :return: true on success
  517. :rtype: bool
  518. """
  519. try:
  520. name = table[1] # 0 is schema
  521. vector_table_names = [t.tableName() for t in self.core_connection.tables('', QgsAbstractDatabaseProviderConnection.Vector)]
  522. if name in vector_table_names:
  523. self.core_connection.renameVectorTable('', name, new_table)
  524. else:
  525. self.core_connection.renameRasterTable('', name, new_table)
  526. return True
  527. except QgsProviderConnectionException:
  528. return False
  529. def moveTable(self, table, new_table, new_schema=None):
  530. return self.renameTable(table, new_table)
  531. def runVacuum(self):
  532. """ run vacuum on the db """
  533. self._execute_and_commit("VACUUM")
  534. def addTableColumn(self, table, field_def):
  535. """Adds a column to table """
  536. _, tablename = self.getSchemaTableName(table)
  537. lyr = self.gdal_ds.GetLayerByName(tablename)
  538. if lyr is None:
  539. return False
  540. fld_defn = self.createOGRFieldDefnFromSQL(field_def)
  541. return lyr.CreateField(fld_defn) == 0
  542. def deleteTableColumn(self, table, column):
  543. """Deletes column from a table """
  544. if self.isGeometryColumn(table, column):
  545. return False
  546. _, tablename = self.getSchemaTableName(table)
  547. lyr = self.gdal_ds.GetLayerByName(tablename)
  548. if lyr is None:
  549. return False
  550. idx = lyr.GetLayerDefn().GetFieldIndex(column)
  551. if idx >= 0:
  552. return lyr.DeleteField(idx) == 0
  553. return False
  554. def updateTableColumn(self, table, column, new_name, new_data_type=None, new_not_null=None, new_default=None, comment=None):
  555. if self.isGeometryColumn(table, column):
  556. return False
  557. _, tablename = self.getSchemaTableName(table)
  558. lyr = self.gdal_ds.GetLayerByName(tablename)
  559. if lyr is None:
  560. return False
  561. if lyr.TestCapability(ogr.OLCAlterFieldDefn) == 0:
  562. return False
  563. idx = lyr.GetLayerDefn().GetFieldIndex(column)
  564. if idx >= 0:
  565. old_fielddefn = lyr.GetLayerDefn().GetFieldDefn(idx)
  566. flag = 0
  567. if new_name is not None:
  568. flag |= ogr.ALTER_NAME_FLAG
  569. else:
  570. new_name = column
  571. if new_data_type is None:
  572. ogr_type = old_fielddefn.GetType()
  573. ogr_subtype = old_fielddefn.GetSubType()
  574. width = old_fielddefn.GetWidth()
  575. else:
  576. flag |= ogr.ALTER_TYPE_FLAG
  577. flag |= ogr.ALTER_WIDTH_PRECISION_FLAG
  578. ogr_type, ogr_subtype, width = self.getOGRFieldTypeFromSQL(new_data_type)
  579. new_fielddefn = ogr.FieldDefn(new_name, ogr_type)
  580. new_fielddefn.SetSubType(ogr_subtype)
  581. new_fielddefn.SetWidth(width)
  582. if new_default is not None:
  583. flag |= ogr.ALTER_DEFAULT_FLAG
  584. if new_default == '':
  585. new_fielddefn.SetDefault(None)
  586. elif new_default == 'NULL' or ogr_type in (ogr.OFTInteger, ogr.OFTReal):
  587. new_fielddefn.SetDefault(str(new_default))
  588. elif new_default.startswith("'") and new_default.endswith("'"):
  589. new_fielddefn.SetDefault(str(new_default))
  590. else:
  591. new_fielddefn.SetDefault(self.quoteString(new_default))
  592. else:
  593. new_fielddefn.SetDefault(old_fielddefn.GetDefault())
  594. if new_not_null is not None:
  595. flag |= ogr.ALTER_NULLABLE_FLAG
  596. new_fielddefn.SetNullable(not new_not_null)
  597. else:
  598. new_fielddefn.SetNullable(old_fielddefn.IsNullable())
  599. return lyr.AlterFieldDefn(idx, new_fielddefn, flag) == 0
  600. return False
  601. def isGeometryColumn(self, table, column):
  602. _, tablename = self.getSchemaTableName(table)
  603. lyr = self.gdal_ds.GetLayerByName(tablename)
  604. if lyr is None:
  605. return False
  606. return column == lyr.GetGeometryColumn()
  607. def addGeometryColumn(self, table, geom_column='geometry', geom_type='POINT', srid=-1, dim=2):
  608. _, tablename = self.getSchemaTableName(table)
  609. lyr = self.gdal_ds.GetLayerByName(tablename)
  610. if lyr is None:
  611. return False
  612. ogr_type = ogr.wkbUnknown
  613. if geom_type == 'POINT':
  614. ogr_type = ogr.wkbPoint
  615. elif geom_type == 'LINESTRING':
  616. ogr_type = ogr.wkbLineString
  617. elif geom_type == 'POLYGON':
  618. ogr_type = ogr.wkbPolygon
  619. elif geom_type == 'MULTIPOINT':
  620. ogr_type = ogr.wkbMultiPoint
  621. elif geom_type == 'MULTILINESTRING':
  622. ogr_type = ogr.wkbMultiLineString
  623. elif geom_type == 'MULTIPOLYGON':
  624. ogr_type = ogr.wkbMultiPolygon
  625. elif geom_type == 'GEOMETRYCOLLECTION':
  626. ogr_type = ogr.wkbGeometryCollection
  627. if dim == 3:
  628. ogr_type = ogr_type | ogr.wkb25DBit
  629. elif dim == 4:
  630. if hasattr(ogr, 'GT_HasZ'):
  631. ogr_type = ogr.GT_SetZ(ogr_type)
  632. else:
  633. ogr_type = ogr_type | ogr.wkb25DBit
  634. if hasattr(ogr, 'GT_HasM'):
  635. ogr_type = ogr.GT_SetM(ogr_type)
  636. geom_field_defn = ogr.GeomFieldDefn(self.unquoteId(geom_column), ogr_type)
  637. if srid > 0:
  638. sr = osr.SpatialReference()
  639. if sr.ImportFromEPSG(srid) == 0:
  640. geom_field_defn.SetSpatialRef(sr)
  641. if lyr.CreateGeomField(geom_field_defn) != 0:
  642. return False
  643. self._opendb()
  644. return True
  645. def deleteGeometryColumn(self, table, geom_column):
  646. return False # not supported
  647. def addTableUniqueConstraint(self, table, column):
  648. """Adds a unique constraint to a table """
  649. return False # constraints not supported
  650. def deleteTableConstraint(self, table, constraint):
  651. """Deletes constraint in a table """
  652. return False # constraints not supported
  653. def addTablePrimaryKey(self, table, column):
  654. """Adds a primery key (with one column) to a table """
  655. sql = "ALTER TABLE %s ADD PRIMARY KEY (%s)" % (self.quoteId(table), self.quoteId(column))
  656. self._execute_and_commit(sql)
  657. def createTableIndex(self, table, name, column, unique=False):
  658. """Creates index on one column using default options """
  659. unique_str = "UNIQUE" if unique else ""
  660. sql = "CREATE %s INDEX %s ON %s (%s)" % (
  661. unique_str, self.quoteId(name), self.quoteId(table), self.quoteId(column))
  662. self._execute_and_commit(sql)
  663. def deleteTableIndex(self, table, name):
  664. schema, tablename = self.getSchemaTableName(table)
  665. sql = "DROP INDEX %s" % self.quoteId((schema, name))
  666. self._execute_and_commit(sql)
  667. def createSpatialIndex(self, table, geom_column):
  668. if self.isRasterTable(table):
  669. return False
  670. _, tablename = self.getSchemaTableName(table)
  671. sql = "SELECT CreateSpatialIndex(%s, %s)" % (
  672. self.quoteId(tablename), self.quoteId(geom_column))
  673. try:
  674. res = self._fetchOne(sql)
  675. except QgsProviderConnectionException:
  676. return False
  677. return res is not None and res[0][0] == 1
  678. def deleteSpatialIndex(self, table, geom_column):
  679. if self.isRasterTable(table):
  680. return False
  681. _, tablename = self.getSchemaTableName(table)
  682. sql = "SELECT DisableSpatialIndex(%s, %s)" % (
  683. self.quoteId(tablename), self.quoteId(geom_column))
  684. res = self._fetchOne(sql)
  685. return len(res) > 0 and len(res[0]) > 0 and res[0][0] == 1
  686. def hasSpatialIndex(self, table, geom_column):
  687. if self.isRasterTable(table) or geom_column is None:
  688. return False
  689. _, tablename = self.getSchemaTableName(table)
  690. # (only available in >= 2.1.2)
  691. sql = "SELECT HasSpatialIndex(%s, %s)" % (self.quoteString(tablename), self.quoteString(geom_column))
  692. gdal.PushErrorHandler()
  693. ret = self._fetchOne(sql)
  694. gdal.PopErrorHandler()
  695. if len(ret) == 0:
  696. # might be the case for GDAL < 2.1.2
  697. sql = "SELECT COUNT(*) FROM sqlite_master WHERE type = 'table' AND name LIKE %s" % self.quoteString("%%rtree_" + tablename + "_%%")
  698. ret = self._fetchOne(sql)
  699. if len(ret) == 0:
  700. return False
  701. else:
  702. return ret[0][0] >= 1
  703. def execution_error_types(self):
  704. return sqlite3.Error, sqlite3.ProgrammingError, sqlite3.Warning
  705. def connection_error_types(self):
  706. return sqlite3.InterfaceError, sqlite3.OperationalError
  707. def getSqlDictionary(self):
  708. from .sql_dictionary import getSqlDictionary
  709. sql_dict = getSqlDictionary()
  710. items = []
  711. for tbl in self.getTables():
  712. items.append(tbl[1]) # table name
  713. for fld in self.getTableFields(tbl[0]):
  714. items.append(fld[1]) # field name
  715. sql_dict["identifier"] = items
  716. return sql_dict
  717. def getQueryBuilderDictionary(self):
  718. from .sql_dictionary import getQueryBuilderDictionary
  719. return getQueryBuilderDictionary()