Skip to content

Commit

Permalink
Support firewall rule name + fix firewall rule issues (#17804)
Browse files Browse the repository at this point in the history
  • Loading branch information
cheenamalhotra authored Sep 18, 2023
1 parent 9276a2f commit 3b0a3a8
Show file tree
Hide file tree
Showing 11 changed files with 162 additions and 81 deletions.
14 changes: 10 additions & 4 deletions localization/xliff/enu/constants/localizedConstants.enu.xlf
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,13 @@
<source xml:lang="en">Database name</source>
</trans-unit>
<trans-unit id="startIpAddressPrompt">
<source xml:lang="en">Start IP</source>
<source xml:lang="en">Start IP Address</source>
</trans-unit>
<trans-unit id="endIpAddressPrompt">
<source xml:lang="en">End IP</source>
<source xml:lang="en">End IP Address</source>
</trans-unit>
<trans-unit id="firewallRuleNamePrompt">
<source xml:lang="en">Firewall rule name</source>
</trans-unit>
<trans-unit id="databasePlaceholder">
<source xml:lang="en">[Optional] Database to connect (press Enter to connect to &lt;default&gt; database)</source>
Expand Down Expand Up @@ -351,7 +354,7 @@
<source xml:lang="en">Your client IP address does not have access to the server. Add an Azure account and create a new firewall rule to enable access.</source>
</trans-unit>
<trans-unit id="msgPromptRetryFirewallRuleSignedIn">
<source xml:lang="en">Account signed In. Create new firewall rule? </source>
<source xml:lang="en">Your client IP Address '{0}' does not have access to the server '{1}' you're attempting to connect to. Would you like to create new firewall rule?</source>
</trans-unit>
<trans-unit id="msgPromptRetryFirewallRuleAdded">
<source xml:lang="en">Firewall rule successfully added. Retry profile creation? </source>
Expand Down Expand Up @@ -387,7 +390,10 @@
<source xml:lang="en">Run Query History</source>
</trans-unit>
<trans-unit id="msgInvalidIpAddress">
<source xml:lang="en">Invalid IP Address </source>
<source xml:lang="en">Invalid IP Address</source>
</trans-unit>
<trans-unit id="msgInvalidRuleName">
<source xml:lang="en">Invalid Firewall rule name</source>
</trans-unit>
<trans-unit id="msgNoQueriesAvailable">
<source xml:lang="en">No Queries Available</source>
Expand Down
16 changes: 11 additions & 5 deletions src/azure/accountService.ts
Original file line number Diff line number Diff line change
Expand Up @@ -69,17 +69,23 @@ export class AccountService {
return account;
}

public async createSecurityTokenMapping(): Promise<any> {
/**
* Creates access token mappings for user selected account and tenant.
* @param account User account to fetch tokens for.
* @param tenantId Tenant Id for which refresh token is needed
* @returns Security token mappings
*/
public async createSecurityTokenMapping(account: IAccount, tenantId: string): Promise<any> {
// TODO: match type for mapping in mssql and sqltoolsservice
let mapping = {};
mapping[this.getHomeTenant(this.account).id] = {
token: (await this.refreshToken(this.account)).token
mapping[tenantId] = {
token: (await this.refreshToken(account, tenantId)).token
};
return mapping;
}

public async refreshToken(account: IAccount): Promise<IToken> {
return await this._azureController.refreshAccessToken(account, this._accountStore, undefined, providerSettings.resources.azureManagementResource);
public async refreshToken(account: IAccount, tenantId: string): Promise<IToken> {
return await this._azureController.refreshAccessToken(account, this._accountStore, tenantId, providerSettings.resources.azureManagementResource);
}

public getHomeTenant(account: IAccount): ITenant {
Expand Down
12 changes: 11 additions & 1 deletion src/constants/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,17 @@ export const azureAccountExtensionId = 'ms-vscode.azure-account';
export const databaseString = 'Database';
export const localizedTexts = 'localizedTexts';
export const ipAddressRegex = /\b(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\b/;
/**
* Azure Firewall rule name convention is specified here:
* https://azure.github.io/PSRule.Rules.Azure/en/rules/Azure.Firewall.Name/
* When naming Azure resources, resource names must meet service requirements. The requirements for Firewall names are:
* - Between 1 and 80 characters long.
* - Alphanumerics, underscores, periods, and hyphens.
* - Start with alphanumeric.
* - End alphanumeric or underscore.
* - Firewall names must be unique within a resource group (we can't do string validation for this, so this is ignored)
*/
export const ruleNameRegex = /^[a-zA-Z0-9][a-zA-Z0-9_.-]{0,78}[a-zA-Z0-9_]?$/;
export const configAzureAccount = 'azureAccount';
export const azureAccountProviderCredentials = 'azureAccountProviderCredentials';
export const msalCacheFileName = 'accessTokenCache';
Expand Down Expand Up @@ -164,7 +175,6 @@ export const sqlToolsServiceConfigKey = 'service';
export const v1SqlToolsServiceConfigKey = 'v1Service';
export const scriptSelectText = 'SELECT TOP (1000) * FROM ';
export const tenantDisplayName = 'Microsoft';
export const firewallErrorMessage = 'To enable access, use the Windows Azure Management Portal or run sp_set_firewall_rule on the master database to create a firewall rule for this IP address or address range.';
export const windowsResourceClientPath = 'SqlToolsResourceProviderService.exe';
export const unixResourceClientPath = 'SqlToolsResourceProviderService';

Expand Down
24 changes: 9 additions & 15 deletions src/controllers/connectionManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -451,27 +451,21 @@ export default class ConnectionManager {
// Check if the error is an expired password
if (result.errorNumber === Constants.errorPasswordExpired || result.errorNumber === Constants.errorPasswordNeedsReset) {
// TODO: we should allow the user to change their password here once corefx supports SqlConnection.ChangePassword()
Utils.showErrorMsg(Utils.formatString(LocalizedConstants.msgConnectionErrorPasswordExpired,
result.errorNumber, result.errorMessage));
connection.errorNumber = result.errorNumber;
connection.errorMessage = result.errorMessage;
} else if (result.errorNumber === Constants.errorSSLCertificateValidationFailed) {
Utils.showErrorMsg(Utils.formatString(LocalizedConstants.msgConnectionErrorPasswordExpired, result.errorNumber, result.errorMessage));
} else if (result.errorNumber === Constants.errorSSLCertificateValidationFailed) { // check if it's an SSL failed error
this._failedUriToSSLMap.set(fileUri, result.errorMessage);
connection.errorNumber = result.errorNumber;
connection.errorMessage = result.errorMessage;
} else if (result.errorNumber !== Constants.errorLoginFailed) {
Utils.showErrorMsg(Utils.formatString(LocalizedConstants.msgConnectionError, result.errorNumber, result.errorMessage));
// check whether it's a firewall rule error
} else if (result.errorNumber === Constants.errorFirewallRule) { // check whether it's a firewall rule error
let firewallResult = await this.firewallService.handleFirewallRule(result.errorNumber, result.errorMessage);
if (firewallResult.result && firewallResult.ipAddress) {
this._failedUriToFirewallIpMap.set(fileUri, firewallResult.ipAddress);
this.failedUriToFirewallIpMap.set(fileUri, firewallResult.ipAddress);
} else {
Utils.showErrorMsg(Utils.formatString(LocalizedConstants.msgConnectionError, result.errorNumber, result.errorMessage));
}
connection.errorNumber = result.errorNumber;
connection.errorMessage = result.errorMessage;
} else {
connection.errorNumber = result.errorNumber;
connection.errorMessage = result.errorMessage;
Utils.showErrorMsg(Utils.formatString(LocalizedConstants.msgConnectionError, result.errorNumber, result.errorMessage));
}
connection.errorNumber = result.errorNumber;
connection.errorMessage = result.errorMessage;
} else {
const platformInfo = await PlatformInformation.getCurrent();
if (!platformInfo.isWindows && result.errorMessage && result.errorMessage.includes('Kerberos')) {
Expand Down
14 changes: 1 addition & 13 deletions src/firewall/firewallService.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,7 @@ export class FirewallService {
private accountService: AccountService
) { }

private async asCreateFirewallRuleParams(serverName: string, startIpAddress: string, endIpAddress?: string): Promise<ICreateFirewallRuleParams> {
let params: ICreateFirewallRuleParams = {
account: this.accountService.account,
serverName: serverName,
startIpAddress: startIpAddress,
endIpAddress: endIpAddress ? endIpAddress : startIpAddress,
securityTokenMappings: await this.accountService.createSecurityTokenMapping()
};
return params;
}

public async createFirewallRule(serverName: string, startIpAddress: string, endIpAddress?: string): Promise<ICreateFirewallRuleResponse> {
let params = await this.asCreateFirewallRuleParams(serverName, startIpAddress, endIpAddress);
public async createFirewallRule(params: ICreateFirewallRuleParams): Promise<ICreateFirewallRuleResponse> {
let result = await this.accountService.client.sendResourceRequest(CreateFirewallRuleRequest.type, params);
return result;
}
Expand Down
1 change: 1 addition & 0 deletions src/models/contracts/firewall/firewallRequest.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ export interface ICreateFirewallRuleParams {
serverName: string;
startIpAddress: string;
endIpAddress: string;
firewallRuleName: string;
securityTokenMappings: {};
}

Expand Down
23 changes: 11 additions & 12 deletions src/objectExplorer/objectExplorerService.ts
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,16 @@ export class ObjectExplorerService {
async updatedProfile => {
self.reconnectProfile(self._currentNode, updatedProfile);
});
} else if (ObjectExplorerUtils.isFirewallError(result.errorNumber)) {
// handle session failure because of firewall issue
let handleFirewallResult = await self._connectionManager.firewallService.handleFirewallRule
(Constants.errorFirewallRule, result.errorMessage);
if (handleFirewallResult.result && handleFirewallResult.ipAddress) {
const nodeUri = ObjectExplorerUtils.getNodeUri(self._currentNode);
const profile = <IConnectionProfile>self._currentNode.connectionInfo;
self.updateNode(self._currentNode);
self._connectionManager.connectionUI.handleFirewallError(nodeUri, profile, handleFirewallResult.ipAddress);
}
} else if (self._currentNode.connectionInfo.authenticationType === Constants.azureMfa
&& self.needsAccountRefresh(result, self._currentNode.connectionInfo.user)) {
let profile = self._currentNode.connectionInfo;
Expand All @@ -152,17 +162,6 @@ export class ObjectExplorerService {
}
const promise = self._sessionIdToPromiseMap.get(result.sessionId);

// handle session failure because of firewall issue
if (ObjectExplorerUtils.isFirewallError(result.errorMessage)) {
let handleFirewallResult = await self._connectionManager.firewallService.handleFirewallRule
(Constants.errorFirewallRule, result.errorMessage);
if (handleFirewallResult.result && handleFirewallResult.ipAddress) {
const nodeUri = ObjectExplorerUtils.getNodeUri(self._currentNode);
const profile = <IConnectionProfile>self._currentNode.connectionInfo;
self.updateNode(self._currentNode);
self._connectionManager.connectionUI.handleFirewallError(nodeUri, profile, handleFirewallResult.ipAddress);
}
}
if (promise) {
return promise.resolve(undefined);
}
Expand Down Expand Up @@ -232,7 +231,7 @@ export class ObjectExplorerService {
TelemetryActions.ExpandNode,
{
nodeType: parentNode.nodeType,
isErrored : (!!result.errorMessage).toString()
isErrored: (!!result.errorMessage).toString()
},
{
nodeCount: result?.nodes.length ?? 0
Expand Down
4 changes: 2 additions & 2 deletions src/objectExplorer/objectExplorerUtils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ export class ObjectExplorerUtils {
return LocalizedConstants.defaultDatabaseLabel;
}

public static isFirewallError(errorMessage: string): boolean {
return errorMessage.includes(Constants.firewallErrorMessage);
public static isFirewallError(errorCode: number): boolean {
return errorCode === Constants.errorFirewallRule;
}
}
109 changes: 91 additions & 18 deletions src/views/connectionUI.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
*--------------------------------------------------------------------------------------------*/

import * as vscode from 'vscode';
import { IAccount, IConnectionInfo } from 'vscode-mssql';
import { IAccount, IConnectionInfo, ITenant } from 'vscode-mssql';
import { AccountStore } from '../azure/accountStore';
import providerSettings from '../azure/providerSettings';
import * as constants from '../constants/constants';
Expand All @@ -14,7 +14,7 @@ import VscodeWrapper from '../controllers/vscodeWrapper';
import { ConnectionCredentials } from '../models/connectionCredentials';
import { ConnectionProfile } from '../models/connectionProfile';
import { ConnectionStore } from '../models/connectionStore';
import { IFirewallIpAddressRange } from '../models/contracts/firewall/firewallRequest';
import { ICreateFirewallRuleParams } from '../models/contracts/firewall/firewallRequest';
import { CredentialsQuickPickItemType, IConnectionCredentialsQuickPickItem, IConnectionProfile } from '../models/interfaces';
import * as Utils from '../models/utils';
import { Timer } from '../models/utils';
Expand Down Expand Up @@ -571,7 +571,36 @@ export class ConnectionUI {
}
}

private async promptForIpAddress(startIpAddress: string): Promise<IFirewallIpAddressRange> {
private async promptForFirewallRuleCreation(startIpAddress: string, server: string): Promise<ICreateFirewallRuleParams | undefined> {
function padTo2Digits(num: number): string {
return num.toString().padStart(2, '0');
}

// format as "YYYY-MM-DD_hh-mm-ss" (default Azure rulename format)
function formatDate(date: Date): string {
return (
[
date.getFullYear(),
padTo2Digits(date.getMonth() + 1),
padTo2Digits(date.getDate())
].join('-') +
'_' +
[
padTo2Digits(date.getHours()),
padTo2Digits(date.getMinutes()),
padTo2Digits(date.getSeconds())
].join('-')
);
}

let azureAccountChoices: INameValueChoice[] = ConnectionProfile.getAccountChoices(this._accountStore);
let tenantChoices: INameValueChoice[] = [];
let defaultFirewallRuleName = `ClientIPAddress_${formatDate(new Date())}`;

let accountAnswer: IAccount;
let tenantIdAnswer: string;
let firewallRuleNameAnswer: string;

let questions: IQuestion[] = [
{
type: QuestionTypes.input,
Expand All @@ -597,31 +626,74 @@ export class ConnectionUI {
}
},
default: startIpAddress
},
{
type: QuestionTypes.input,
name: LocalizedConstants.firewallRuleNamePrompt,
message: LocalizedConstants.firewallRuleNamePrompt,
placeHolder: defaultFirewallRuleName,
validate: (value: string) => {
if (!value.match(constants.ruleNameRegex)) {
return LocalizedConstants.msgInvalidRuleName;
}
firewallRuleNameAnswer = value;
},
default: defaultFirewallRuleName
}, ,
{
type: QuestionTypes.expand,
name: LocalizedConstants.aad,
message: LocalizedConstants.azureChooseAccount,
choices: azureAccountChoices,
onAnswered: async (value: IAccount) => {
accountAnswer = value;
let account = value;
tenantChoices.push(...account?.properties?.tenants!.map(t => ({ name: t.displayName, value: t })));
if (tenantChoices.length === 1) {
tenantIdAnswer = tenantChoices[0].value.id;
}
}
},
{
type: QuestionTypes.expand,
name: LocalizedConstants.tenant,
message: LocalizedConstants.azureChooseTenant,
choices: tenantChoices,
shouldPrompt: () => tenantChoices.length > 1,
onAnswered: (value: ITenant) => {
tenantIdAnswer = value.id;
}
}
];

// Prompt and return the value if the user confirmed
return this._prompter.prompt(questions).then((answers: { [questionId: string]: string }) => {
if (answers) {
let result: IFirewallIpAddressRange = {
startIpAddress: answers[LocalizedConstants.startIpAddressPrompt] ?
answers[LocalizedConstants.startIpAddressPrompt] : startIpAddress,
endIpAddress: answers[LocalizedConstants.endIpAddressPrompt] ?
answers[LocalizedConstants.endIpAddressPrompt] : startIpAddress
};
return result;
}
});
let answers = await this._prompter.prompt(questions);
if (answers) {
let result: ICreateFirewallRuleParams = {
account: accountAnswer,
startIpAddress: answers[LocalizedConstants.startIpAddressPrompt] ?
answers[LocalizedConstants.startIpAddressPrompt] as string : startIpAddress,
endIpAddress: answers[LocalizedConstants.endIpAddressPrompt] ?
answers[LocalizedConstants.endIpAddressPrompt] as string : startIpAddress,
firewallRuleName: firewallRuleNameAnswer,
serverName: server,
securityTokenMappings: await this.connectionManager.accountService.createSecurityTokenMapping(accountAnswer, tenantIdAnswer)
};
return result;
} else {
return undefined;
}
}

private async createFirewallRule(serverName: string, ipAddress: string): Promise<boolean> {
let result = await this._vscodeWrapper.showInformationMessage(LocalizedConstants.msgPromptRetryFirewallRuleSignedIn,
let result = await this._vscodeWrapper.showInformationMessage(
Utils.formatString(LocalizedConstants.msgPromptRetryFirewallRuleSignedIn, ipAddress, serverName),
LocalizedConstants.createFirewallRuleLabel);
if (result === LocalizedConstants.createFirewallRuleLabel) {
const firewallService = this.connectionManager.firewallService;
let ipRange = await this.promptForIpAddress(ipAddress);
if (ipRange) {
let firewallResult = await firewallService.createFirewallRule(serverName, ipRange.startIpAddress, ipRange.endIpAddress);
let params = await this.promptForFirewallRuleCreation(ipAddress, serverName);
if (params) {
let firewallResult = await firewallService.createFirewallRule(params);
if (firewallResult.result) {
this._vscodeWrapper.showInformationMessage(LocalizedConstants.msgPromptFirewallRuleCreated);
return true;
Expand Down Expand Up @@ -729,3 +801,4 @@ export class ConnectionUI {
});
}
}

Loading

0 comments on commit 3b0a3a8

Please sign in to comment.