Skip to content

Commit

Permalink
Merge pull request #623 from holistics/add-snowflake-connector
Browse files Browse the repository at this point in the history
Add Snowflake connector
  • Loading branch information
huyphung1602 authored Aug 27, 2024
2 parents 031a351 + 61e30d4 commit 17f38c8
Show file tree
Hide file tree
Showing 13 changed files with 2,444 additions and 177 deletions.
4 changes: 2 additions & 2 deletions packages/dbml-cli/src/cli/connector.js
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ import logger from '../helpers/logger';

export default async function connectionHandler (program) {
try {
const { connection, format } = getConnectionOpt(program.args);
const { connection, databaseType } = getConnectionOpt(program.args);
const opts = program.opts();
const schemaJson = await connector.fetchSchemaJson(connection, format);
const schemaJson = await connector.fetchSchemaJson(connection, databaseType);

if (!opts.outFile && !opts.outDir) {
const res = importer.generateDbml(schemaJson);
Expand Down
12 changes: 7 additions & 5 deletions packages/dbml-cli/src/cli/index.js
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
/* eslint-disable max-len */
import program from 'commander';
import importHandler from './import';
import exportHandler from './export';
Expand Down Expand Up @@ -56,14 +57,15 @@ function db2dbml (args) {
// - postgres: postgresql://user:password@localhost:5432/dbname
// - mssql: 'Server=localhost,1433;Database=master;User Id=sa;Password=your_password;Encrypt=true;TrustServerCertificate
const description = `
<format> your database format (postgres, mysql, mssql)
<database-type> your database format (postgres, mysql, mssql, snowflake)
<connection-string> your database connection string:
- postgres: postgresql://user:password@localhost:5432/dbname
- mysql: mysql://user:password@localhost:3306/dbname
- mssql: 'Server=localhost,1433;Database=master;User Id=sa;Password=your_password;Encrypt=true;TrustServerCertificate=true;'
- postgres: 'postgresql://user:password@localhost:5432/dbname?schemas=schema1,schema2,schema3'
- mysql: 'mysql://user:password@localhost:3306/dbname'
- mssql: 'Server=localhost,1433;Database=master;User Id=sa;Password=your_password;Encrypt=true;TrustServerCertificate=true;Schemas=schema1,schema2,schema3;'
- snowflake: 'SERVER=<account_identifier>.<region>;UID=<your_username>;PWD=<your_password>;DATABASE=<your_database>;WAREHOUSE=<your_warehouse>;ROLE=<your_role>;SCHEMAS=schema1,schema2,schema3;'
`;
program
.usage('<format> <connection-string> [options]')
.usage('<database-type> <connection-string> [options]')
.description(description)
.option('-o, --out-file <pathspec>', 'compile all input files into a single file');

Expand Down
6 changes: 3 additions & 3 deletions packages/dbml-cli/src/cli/utils.js
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,14 @@ function getFormatOpt (opts) {
}

function getConnectionOpt (args) {
const supportedDatabases = ['postgres', 'mysql', 'mssql'];
const supportedDatabases = ['postgres', 'mysql', 'mssql', 'snowflake'];
const defaultConnectionOpt = {
connection: args[0],
format: 'unknown',
databaseType: 'unknown',
};

return reduce(args, (connectionOpt, arg) => {
if (supportedDatabases.includes(arg)) connectionOpt.format = arg;
if (supportedDatabases.includes(arg)) connectionOpt.databaseType = arg;
// Check if the arg is a connection string using regex
const connectionStringRegex = /^.*[:;]/;
if (connectionStringRegex.test(arg)) {
Expand Down
3 changes: 2 additions & 1 deletion packages/dbml-connector/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@
"dependencies": {
"mssql": "^11.0.1",
"mysql2": "^3.11.0",
"pg": "^8.12.0"
"pg": "^8.12.0",
"snowflake-sdk": "^1.12.0"
},
"engines": {
"node": ">=18"
Expand Down
9 changes: 6 additions & 3 deletions packages/dbml-connector/src/connectors/connector.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,20 @@ import { DatabaseSchema } from './types';
import { fetchSchemaJson as fetchPostgresSchemaJson } from './postgresConnector';
import { fetchSchemaJson as fetchMssqlSchemaJson } from './mssqlConnector';
import { fetchSchemaJson as fetchMysqlSchemaJson } from './mysqlConnector';
import { fetchSchemaJson as fetchSnowflakeSchemaJson } from './snowflakeConnector';

const fetchSchemaJson = async (connection: string, format: string): Promise<DatabaseSchema> => {
switch (format) {
const fetchSchemaJson = async (connection: string, databaseType: string): Promise<DatabaseSchema> => {
switch (databaseType) {
case 'postgres':
return fetchPostgresSchemaJson(connection);
case 'mssql':
return fetchMssqlSchemaJson(connection);
case 'mysql':
return fetchMysqlSchemaJson(connection);
case 'snowflake':
return fetchSnowflakeSchemaJson(connection);
default:
throw new Error(`Unsupported connection format: ${format}`);
throw new Error(`Unsupported database type: ${databaseType}`);
}
};

Expand Down
150 changes: 76 additions & 74 deletions packages/dbml-connector/src/connectors/mssqlConnector.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import {
DatabaseSchema,
DefaultInfo,
} from './types';
import { buildSchemaQuery, parseConnectionString } from '../utils/parseSchema';

const MSSQL_DATE_TYPES = [
'date',
Expand Down Expand Up @@ -132,83 +133,79 @@ const generateField = (row: Record<string, any>): Field => {
};
};

const generateTablesFieldsAndEnums = async (client: sql.ConnectionPool): Promise<{
const generateTablesFieldsAndEnums = async (client: sql.ConnectionPool, schemas: string[]): Promise<{
tables: Table[],
fields: FieldsDictionary,
enums: Enum[],
}> => {
const fields: FieldsDictionary = {};
const enums: Enum[] = [];
const tablesAndFieldsSql = `
WITH tables_and_fields AS (
SELECT
s.name AS table_schema,
t.name AS table_name,
c.name AS column_name,
ty.name AS data_type,
c.max_length AS character_maximum_length,
c.precision AS numeric_precision,
c.scale AS numeric_scale,
c.is_identity AS identity_increment,
CASE
WHEN c.is_nullable = 1 THEN 'YES'
ELSE 'NO'
END AS is_nullable,
CASE
WHEN c.default_object_id = 0 THEN NULL
ELSE OBJECT_DEFINITION(c.default_object_id)
END AS column_default,
-- Fetching table comments
p.value AS table_comment,
ep.value AS column_comment
FROM
sys.tables t
JOIN
sys.schemas s ON t.schema_id = s.schema_id
JOIN
sys.columns c ON t.object_id = c.object_id
JOIN
sys.types ty ON c.user_type_id = ty.user_type_id
LEFT JOIN
sys.extended_properties p ON p.major_id = t.object_id
AND p.name = 'MS_Description'
AND p.minor_id = 0 -- Ensure minor_id is 0 for table comments
LEFT JOIN
sys.extended_properties ep ON ep.major_id = c.object_id
AND ep.minor_id = c.column_id
AND ep.name = 'MS_Description'
WHERE
t.type = 'U' -- User-defined tables
)
WITH tables_and_fields AS (
SELECT
tf.table_schema,
tf.table_name,
tf.column_name,
tf.data_type,
tf.character_maximum_length,
tf.numeric_precision,
tf.numeric_scale,
tf.identity_increment,
tf.is_nullable,
tf.column_default,
tf.table_comment,
tf.column_comment,
cc.name AS check_constraint_name, -- Adding CHECK constraint name
cc.definition AS check_constraint_definition, -- Adding CHECK constraint definition
CASE
WHEN tf.column_default LIKE '((%))' THEN 'number'
WHEN tf.column_default LIKE '(''%'')' THEN 'string'
ELSE 'expression'
END AS default_type
s.name AS table_schema,
t.name AS table_name,
c.name AS column_name,
ty.name AS data_type,
c.max_length AS character_maximum_length,
c.precision AS numeric_precision,
c.scale AS numeric_scale,
c.is_identity AS identity_increment,
CASE
WHEN c.is_nullable = 1 THEN 'YES'
ELSE 'NO'
END AS is_nullable,
CASE
WHEN c.default_object_id = 0 THEN NULL
ELSE OBJECT_DEFINITION(c.default_object_id)
END AS column_default,
-- Fetching table comments
p.value AS table_comment,
ep.value AS column_comment
FROM
tables_and_fields AS tf
LEFT JOIN
sys.check_constraints cc ON cc.parent_object_id = OBJECT_ID(tf.table_schema + '.' + tf.table_name)
AND cc.definition LIKE '%' + tf.column_name + '%' -- Ensure the constraint references the column
ORDER BY
tf.table_schema,
tf.table_name,
tf.column_name;
sys.tables t
JOIN sys.schemas s ON t.schema_id = s.schema_id
JOIN sys.columns c ON t.object_id = c.object_id
JOIN sys.types ty ON c.user_type_id = ty.user_type_id
LEFT JOIN sys.extended_properties p ON p.major_id = t.object_id
AND p.name = 'MS_Description'
AND p.minor_id = 0 -- Ensure minor_id is 0 for table comments
LEFT JOIN sys.extended_properties ep ON ep.major_id = c.object_id
AND ep.minor_id = c.column_id
AND ep.name = 'MS_Description'
WHERE
t.type = 'U' -- User-defined tables
)
SELECT
tf.table_schema,
tf.table_name,
tf.column_name,
tf.data_type,
tf.character_maximum_length,
tf.numeric_precision,
tf.numeric_scale,
tf.identity_increment,
tf.is_nullable,
tf.column_default,
tf.table_comment,
tf.column_comment,
cc.name AS check_constraint_name, -- Adding CHECK constraint name
cc.definition AS check_constraint_definition, -- Adding CHECK constraint definition
CASE
WHEN tf.column_default LIKE '((%))' THEN 'number'
WHEN tf.column_default LIKE '(''%'')' THEN 'string'
ELSE 'expression'
END AS default_type
FROM
tables_and_fields AS tf
LEFT JOIN sys.check_constraints cc
ON cc.parent_object_id = OBJECT_ID(tf.table_schema + '.' + tf.table_name)
AND cc.definition LIKE '%' + tf.column_name + '%' -- Ensure the constraint references the column
${buildSchemaQuery('tf.table_schema', schemas, 'WHERE')}
ORDER BY
tf.table_schema,
tf.table_name,
tf.column_name;
`;

const tablesAndFieldsResult = await client.query(tablesAndFieldsSql);
Expand Down Expand Up @@ -259,7 +256,7 @@ const generateTablesFieldsAndEnums = async (client: sql.ConnectionPool): Promise
};
};

const generateRefs = async (client: sql.ConnectionPool): Promise<Ref[]> => {
const generateRefs = async (client: sql.ConnectionPool, schemas: string[]): Promise<Ref[]> => {
const refs: Ref[] = [];

const refsListSql = `
Expand Down Expand Up @@ -290,6 +287,7 @@ const generateRefs = async (client: sql.ConnectionPool): Promise<Ref[]> => {
JOIN sys.tables AS t2 ON fk.referenced_object_id = t2.object_id
JOIN sys.schemas AS s2 ON t2.schema_id = s2.schema_id
WHERE s.name NOT IN ('sys', 'information_schema')
${buildSchemaQuery('s.name', schemas)}
ORDER BY
s.name,
t.name;
Expand Down Expand Up @@ -334,7 +332,7 @@ const generateRefs = async (client: sql.ConnectionPool): Promise<Ref[]> => {
return refs;
};

const generateIndexes = async (client: sql.ConnectionPool) => {
const generateIndexes = async (client: sql.ConnectionPool, schemas: string[]) => {
const indexListSql = `
WITH user_tables AS (
SELECT
Expand All @@ -352,6 +350,7 @@ const generateIndexes = async (client: sql.ConnectionPool) => {
),
index_info AS (
SELECT
SCHEMA_NAME(t.schema_id) AS table_schema, -- Add schema information
OBJECT_NAME(i.object_id) AS table_name,
i.name AS index_name,
i.is_unique,
Expand Down Expand Up @@ -399,8 +398,10 @@ const generateIndexes = async (client: sql.ConnectionPool) => {
user_tables ut
LEFT JOIN
index_info ii ON ut.TABLE_NAME = ii.table_name
AND ut.TABLE_SCHEMA = ii.table_schema
WHERE
ii.columns IS NOT NULL
${buildSchemaQuery('ut.TABLE_SCHEMA', schemas)}
ORDER BY
ut.TABLE_NAME,
ii.constraint_type,
Expand Down Expand Up @@ -491,11 +492,12 @@ const generateIndexes = async (client: sql.ConnectionPool) => {
};

const fetchSchemaJson = async (connection: string): Promise<DatabaseSchema> => {
const client = await getValidatedClient(connection);
const { connectionString, schemas } = parseConnectionString(connection, 'odbc');
const client = await getValidatedClient(connectionString);

const tablesFieldsAndEnumsRes = generateTablesFieldsAndEnums(client);
const indexesRes = generateIndexes(client);
const refsRes = generateRefs(client);
const tablesFieldsAndEnumsRes = generateTablesFieldsAndEnums(client, schemas);
const indexesRes = generateIndexes(client, schemas);
const refsRes = generateRefs(client, schemas);

const res = await Promise.all([
tablesFieldsAndEnumsRes,
Expand Down
Loading

0 comments on commit 17f38c8

Please sign in to comment.