diff --git a/mssql/features.py b/mssql/features.py index faaa0bbb..a27e90b7 100644 --- a/mssql/features.py +++ b/mssql/features.py @@ -33,6 +33,7 @@ class DatabaseFeatures(BaseDatabaseFeatures): requires_literal_defaults = True requires_sqlparse_for_splitting = False supports_boolean_expr_in_select_clause = False + supports_comments = True supports_covering_indexes = True supports_deferrable_unique_constraints = False supports_expression_indexes = False diff --git a/mssql/introspection.py b/mssql/introspection.py index 5f33bff9..efba52ad 100644 --- a/mssql/introspection.py +++ b/mssql/introspection.py @@ -4,10 +4,12 @@ from django.db import DatabaseError import pyodbc as Database +from collections import namedtuple + from django import VERSION -from django.db.backends.base.introspection import ( - BaseDatabaseIntrospection, FieldInfo, TableInfo, -) +from django.db.backends.base.introspection import BaseDatabaseIntrospection +from django.db.backends.base.introspection import FieldInfo as BaseFieldInfo +from django.db.backends.base.introspection import TableInfo as BaseTableInfo from django.db.models.indexes import Index from django.conf import settings @@ -16,6 +18,8 @@ SQL_SMALLAUTOFIELD = -777333 SQL_TIMESTAMP_WITH_TIMEZONE = -155 +FieldInfo = namedtuple("FieldInfo", BaseFieldInfo._fields + ("comment",)) +TableInfo = namedtuple("TableInfo", BaseTableInfo._fields + ("comment",)) def get_schema_name(): return getattr(settings, 'SCHEMA_TO_INSPECT', 'SCHEMA_NAME()') @@ -73,13 +77,26 @@ def get_table_list(self, cursor): """ Returns a list of table and view names in the current database. """ - sql = 'SELECT TABLE_NAME, TABLE_TYPE FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = %s' % ( + sql = """SELECT + TABLE_NAME, + TABLE_TYPE, + CAST(ep.value AS VARCHAR) AS COMMENT + FROM INFORMATION_SCHEMA.TABLES i + LEFT JOIN sys.tables t ON t.name = i.TABLE_NAME + LEFT JOIN sys.extended_properties ep ON t.object_id = ep.major_id + AND ((ep.name = 'MS_DESCRIPTION' AND ep.minor_id = 0) OR ep.value IS NULL) + AND i.TABLE_SCHEMA = %s""" % ( get_schema_name()) cursor.execute(sql) types = {'BASE TABLE': 't', 'VIEW': 'v'} - return [TableInfo(row[0], types.get(row[1])) - for row in cursor.fetchall() - if row[0] not in self.ignored_tables] + if VERSION >= (4, 2): + return [TableInfo(row[0], types.get(row[1]), row[2]) + for row in cursor.fetchall() + if row[0] not in self.ignored_tables] + else: + return [BaseTableInfo(row[0], types.get(row[1])) + for row in cursor.fetchall() + if row[0] not in self.ignored_tables] def _is_auto_field(self, cursor, table_name, column_name): """ @@ -113,7 +130,7 @@ def get_table_description(self, cursor, table_name, identity_check=True): if not columns: raise DatabaseError(f"Table {table_name} does not exist.") - + items = [] for column in columns: if VERSION >= (3, 2): @@ -128,7 +145,16 @@ def get_table_description(self, cursor, table_name, identity_check=True): column.append(collation_name[0] if collation_name else '') else: column.append('') - + if VERSION >= (4, 2): + sql = """select CAST(ep.value AS VARCHAR) AS COMMENT + FROM sys.columns c + INNER JOIN sys.tables t ON c.object_id = t.object_id + INNER JOIN sys.extended_properties ep ON c.object_id=ep.major_id AND ep.minor_id = c.column_id + WHERE t.name = '%s' AND c.name = '%s' AND ep.name = 'MS_Description' + """ % (table_name, column[0]) + cursor.execute(sql) + comment = cursor.fetchone() + column.append(comment[0] if comment else '') if identity_check and self._is_auto_field(cursor, table_name, column[0]): if column[1] == Database.SQL_BIGINT: column[1] = SQL_BIGAUTOFIELD @@ -138,7 +164,10 @@ def get_table_description(self, cursor, table_name, identity_check=True): column[1] = SQL_AUTOFIELD if column[1] == Database.SQL_WVARCHAR and column[3] < 4000: column[1] = Database.SQL_WCHAR - items.append(FieldInfo(*column)) + if VERSION >= (4, 2): + items.append(FieldInfo(*column)) + else: + items.append(BaseFieldInfo(*column)) return items def get_sequences(self, cursor, table_name, table_fields=()): diff --git a/mssql/schema.py b/mssql/schema.py index 0785d761..771b93b6 100644 --- a/mssql/schema.py +++ b/mssql/schema.py @@ -93,7 +93,40 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): sql_rename_table = "EXEC sp_rename %(old_table)s, %(new_table)s" sql_create_unique_null = "CREATE UNIQUE INDEX %(name)s ON %(table)s(%(columns)s) " \ "WHERE %(columns)s IS NOT NULL" - + sql_alter_table_comment= """ + IF NOT EXISTS (SELECT NULL FROM sys.extended_properties ep + WHERE ep.major_id = OBJECT_ID('%(table)s') + AND ep.name = 'MS_Description' + AND ep.minor_id = 0) + EXECUTE sp_addextendedproperty + @name = 'MS_Description', @value = %(comment)s, + @level0type = 'SCHEMA', @level0name = 'dbo', + @level1type = 'TABLE', @level1name = %(table)s + ELSE + EXECUTE sp_updateextendedproperty + @name = 'MS_Description', @value = %(comment)s, + @level0type = 'SCHEMA', @level0name = 'dbo', + @level1type = 'TABLE', @level1name = %(table)s + """ + sql_alter_column_comment= """ + IF NOT EXISTS (SELECT NULL FROM sys.extended_properties ep + WHERE ep.major_id = OBJECT_ID('%(table)s') + AND ep.name = 'MS_Description' + AND ep.minor_id = (SELECT column_id FROM sys.columns + WHERE name = '%(column)s' + AND object_id = OBJECT_ID('%(table)s'))) + EXECUTE sp_addextendedproperty + @name = 'MS_Description', @value = %(comment)s, + @level0type = 'SCHEMA', @level0name = 'dbo', + @level1type = 'TABLE', @level1name = %(table)s, + @level2type = 'COLUMN', @level2name = %(column)s + ELSE + EXECUTE sp_updateextendedproperty + @name = 'MS_Description', @value = %(comment)s, + @level0type = 'SCHEMA', @level0name = 'dbo', + @level1type = 'TABLE', @level1name = %(table)s, + @level2type = 'COLUMN', @level2name = %(column)s + """ _deferred_unique_indexes = defaultdict(list) def _alter_column_default_sql(self, model, old_field, new_field, drop=False): @@ -138,7 +171,18 @@ def _alter_column_default_sql(self, model, old_field, new_field, drop=False): }, params, ) - + + def _alter_column_comment_sql(self, model, new_field, new_type, new_db_comment): + return ( + self.sql_alter_column_comment + % { + "table": self.quote_name(model._meta.db_table), + "column": new_field.column, + "comment": self._comment_sql(new_db_comment), + }, + [], + ) + def _alter_column_null_sql(self, model, old_field, new_field): """ Hook to specialize column null alteration. @@ -316,7 +360,19 @@ def _alter_field(self, model, old_field, new_field, old_type, new_type, # Drop any FK constraints, we'll remake them later fks_dropped = set() - if old_field.remote_field and old_field.db_constraint: + if ( + old_field.remote_field + and old_field.db_constraint + and (django_version < (4,2) + or + (django_version >= (4, 2) + and self._field_should_be_altered( + old_field, + new_field, + ignore={"db_comment"}) + ) + ) + ): # Drop index, SQL Server requires explicit deletion if not hasattr(new_field, 'db_constraint') or not new_field.db_constraint: index_names = self._constraint_names(model, [old_field.column], index=True) @@ -446,8 +502,11 @@ def _alter_field(self, model, old_field, new_field, old_type, new_type, actions = [] null_actions = [] post_actions = [] - # Type change? - if old_type != new_type: + # Type or comment change? + if old_type != new_type or (django_version >= (4, 2) and + self.connection.features.supports_comments + and old_field.db_comment != new_field.db_comment + ): if django_version >= (4, 2): fragment, other_actions = self._alter_column_type_sql( model, old_field, new_field, new_type, old_collation=None, new_collation=None @@ -922,6 +981,19 @@ def add_field(self, model, field): "changes": changes_sql, } self.execute(sql, params) + # Add field comment, if required. + if django_version >= (4, 2): + if ( + field.db_comment + and self.connection.features.supports_comments + and not self.connection.features.supports_comments_inline + ): + field_type = db_params["type"] + self.execute( + *self._alter_column_comment_sql( + model, field, field_type, field.db_comment + ) + ) # Add an index, if required self.deferred_sql.extend(self._field_indexes_sql(model, field)) # Add any FK constraints later @@ -1129,6 +1201,23 @@ def create_model(self, model): # Prevent using [] as params, in the case a literal '%' is used in the definition self.execute(sql, params or None) + if django_version >= (4, 2) and self.connection.features.supports_comments: + # Add table comment. + if model._meta.db_table_comment: + self.alter_db_table_comment(model, None, model._meta.db_table_comment) + # Add column comments. + if not self.connection.features.supports_comments_inline: + for field in model._meta.local_fields: + if field.db_comment: + field_db_params = field.db_parameters( + connection=self.connection + ) + field_type = field_db_params["type"] + self.execute( + *self._alter_column_comment_sql( + model, field, field_type, field.db_comment + ) + ) # Add any field index and index_together's (deferred as SQLite3 _remake_table needs it) self.deferred_sql.extend(self._model_indexes_sql(model)) self.deferred_sql = list(set(self.deferred_sql))