Skip to content

Commit

Permalink
cosmetic cleanup from PR review
Browse files Browse the repository at this point in the history
  • Loading branch information
alecps committed Aug 23, 2023
1 parent eb76a81 commit f47f7c6
Show file tree
Hide file tree
Showing 13 changed files with 107 additions and 264 deletions.
35 changes: 22 additions & 13 deletions packages/phone-number-privacy/combiner/src/common/combine.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import {
OdisRequest,
OdisResponse,
responseHasExpectedKeyVersion,
SignerEndpoint,
WarningMessage,
} from '@celo/phone-number-privacy-common'
import Logger from 'bunyan'
Expand All @@ -17,15 +18,19 @@ export interface Signer {
fallbackUrl?: string
}

export interface ThresholdCallToSignersOptions<R extends OdisRequest> {
signers: Signer[]
endpoint: SignerEndpoint
requestTimeoutMS: number
shouldCheckKeyVersion: boolean
keyVersionInfo: KeyVersionInfo
request: Request<{}, {}, R>
responseSchema: t.Type<OdisResponse<R>, OdisResponse<R>, unknown>
}

export async function thresholdCallToSigners<R extends OdisRequest>(
logger: Logger,
signers: Signer[],
endpoint: string,
request: Request<{}, {}, R>,
keyVersionInfo: KeyVersionInfo,
requestTimeoutMS: number,
responseSchema: t.Type<OdisResponse<R>, OdisResponse<R>, unknown>,
shouldCheckKeyVersion: boolean = false,
options: ThresholdCallToSignersOptions<R>,
processResult: (res: OdisResponse<R>) => Promise<boolean> = (_) => Promise.resolve(false)
): Promise<{ signerResponses: Array<SignerResponse<R>>; maxErrorCode?: number }> {
const obs = new PerformanceObserver((list) => {
Expand All @@ -42,6 +47,16 @@ export async function thresholdCallToSigners<R extends OdisRequest>(
})
obs.observe({ entryTypes: ['measure'], buffered: false })

const {
signers,
endpoint,
requestTimeoutMS,
shouldCheckKeyVersion,
keyVersionInfo,
request,
responseSchema,
} = options

const manualAbort = new AbortController()
const timeoutSignal = AbortSignal.timeout(requestTimeoutMS)
const abortSignal = (AbortSignal as any).any([manualAbort.signal, timeoutSignal]) as AbortSignal
Expand Down Expand Up @@ -81,7 +96,6 @@ export async function thresholdCallToSigners<R extends OdisRequest>(
return
}

// if given key version, check that
if (
shouldCheckKeyVersion &&
!responseHasExpectedKeyVersion(signerFetchResult, keyVersionInfo.keyVersion, logger)
Expand Down Expand Up @@ -120,16 +134,11 @@ export async function thresholdCallToSigners<R extends OdisRequest>(
logger.error({ signer }, ErrorMessage.SIGNER_REQUEST_ERROR)
logger.error({ signer, err })

// Tracking failed request count via signer url prevents
// double counting the same failed request by mistake
errorCount++
if (signers.length - errorCount < requiredThreshold) {
logger.warn('Not possible to reach a threshold of signer responses. Failing fast')
manualAbort.abort()
}

// TODO (mcortesi) doesn't seem we need to fail at first error
// throw err
}
}
})
Expand Down
5 changes: 3 additions & 2 deletions packages/phone-number-privacy/combiner/src/common/io.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import {
ErrorType,
getRequestKeyVersion,
KeyVersionInfo,
KEY_VERSION_HEADER,
KeyVersionInfo,
OdisRequest,
OdisResponse,
requestHasValidKeyVersion,
send,
SignerEndpoint,
} from '@celo/phone-number-privacy-common'
import Logger from 'bunyan'
import { Request, Response } from 'express'
Expand Down Expand Up @@ -66,7 +67,7 @@ export function getKeyVersionInfo(

export async function fetchSignerResponseWithFallback<R extends OdisRequest>(
signer: Signer,
signerEndpoint: string,
signerEndpoint: SignerEndpoint,
keyVersion: number,
request: Request<{}, {}, R>,
logger: Logger,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import {
disableDomainResponseSchema,
DomainSchema,
ErrorMessage,
getSignerEndpoint,
send,
SequentialDelayDomainStateSchema,
verifyDisableDomainRequestAuthenticity,
Expand All @@ -13,19 +14,16 @@ import {
import { Signer, thresholdCallToSigners } from '../../../common/combine'
import { PromiseHandler } from '../../../common/handlers'
import { getKeyVersionInfo, sendFailure } from '../../../common/io'

import { getCombinerVersion, OdisConfig } from '../../../config'
import { logDomainResponsesDiscrepancies } from '../../services/log-responses'
import { logDomainResponseDiscrepancies } from '../../services/log-responses'
import { findThresholdDomainState } from '../../services/threshold-state'

export function createDisableDomainHandler(
signers: Signer[],
config: OdisConfig
): PromiseHandler<DisableDomainRequest> {
const requestSchema = disableDomainRequestSchema(DomainSchema)
const signerEndpoint = CombinerEndpoint.DISABLE_DOMAIN
return async (request, response) => {
if (!requestSchema.is(request.body)) {
if (!disableDomainRequestSchema(DomainSchema).is(request.body)) {
sendFailure(WarningMessage.INVALID_INPUT, 400, response)
return
}
Expand All @@ -35,19 +33,23 @@ export function createDisableDomainHandler(
return
}

// TODO remove?
const keyVersionInfo = getKeyVersionInfo(request, config, response.locals.logger)

const { signerResponses, maxErrorCode } = await thresholdCallToSigners(
const { signerResponses, maxErrorCode } = await thresholdCallToSigners<DisableDomainRequest>(
response.locals.logger,
signers,
signerEndpoint,
request,
keyVersionInfo,
config.odisServices.timeoutMilliSeconds,
disableDomainResponseSchema(SequentialDelayDomainStateSchema)
{
signers,
endpoint: getSignerEndpoint(CombinerEndpoint.DISABLE_DOMAIN),
request,
keyVersionInfo,
requestTimeoutMS: config.odisServices.timeoutMilliSeconds,
responseSchema: disableDomainResponseSchema(SequentialDelayDomainStateSchema),
shouldCheckKeyVersion: false,
}
)

logDomainResponsesDiscrepancies(response.locals.logger, signerResponses)
logDomainResponseDiscrepancies(response.locals.logger, signerResponses)
try {
const disableDomainStatus = findThresholdDomainState(
keyVersionInfo,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import {
domainQuotaStatusResponseSchema,
DomainSchema,
ErrorMessage,
getSignerEndpoint,
send,
SequentialDelayDomainStateSchema,
verifyDomainQuotaStatusRequestAuthenticity,
Expand All @@ -13,19 +14,16 @@ import {
import { Signer, thresholdCallToSigners } from '../../../common/combine'
import { PromiseHandler } from '../../../common/handlers'
import { getKeyVersionInfo, sendFailure } from '../../../common/io'

import { getCombinerVersion, OdisConfig } from '../../../config'
import { logDomainResponsesDiscrepancies } from '../../services/log-responses'
import { logDomainResponseDiscrepancies } from '../../services/log-responses'
import { findThresholdDomainState } from '../../services/threshold-state'

export function createDomainQuotaHandler(
signers: Signer[],
config: OdisConfig
): PromiseHandler<DomainQuotaStatusRequest> {
const requestSchema = domainQuotaStatusRequestSchema(DomainSchema)
const signerEndpoint = CombinerEndpoint.DOMAIN_QUOTA_STATUS
return async (request, response) => {
if (!requestSchema.is(request.body)) {
if (!domainQuotaStatusRequestSchema(DomainSchema).is(request.body)) {
sendFailure(WarningMessage.INVALID_INPUT, 400, response)
return
}
Expand All @@ -35,19 +33,20 @@ export function createDomainQuotaHandler(
return
}

// TODO remove?
const keyVersionInfo = getKeyVersionInfo(request, config, response.locals.logger)

const { signerResponses, maxErrorCode } = await thresholdCallToSigners(
response.locals.logger,
const { signerResponses, maxErrorCode } = await thresholdCallToSigners(response.locals.logger, {
signers,
signerEndpoint,
endpoint: getSignerEndpoint(CombinerEndpoint.DOMAIN_QUOTA_STATUS),
request,
keyVersionInfo,
config.odisServices.timeoutMilliSeconds,
domainQuotaStatusResponseSchema(SequentialDelayDomainStateSchema)
)
requestTimeoutMS: config.odisServices.timeoutMilliSeconds,
responseSchema: domainQuotaStatusResponseSchema(SequentialDelayDomainStateSchema),
shouldCheckKeyVersion: false,
})

logDomainResponsesDiscrepancies(response.locals.logger, signerResponses)
logDomainResponseDiscrepancies(response.locals.logger, signerResponses)
if (signerResponses.length >= keyVersionInfo.threshold) {
try {
send(
Expand All @@ -65,6 +64,6 @@ export function createDomainQuotaHandler(
response.locals.logger.error(err, 'Error combining signer quota status responses')
}
}
sendFailure(ErrorMessage.THRESHOLD_DISABLE_DOMAIN_FAILURE, maxErrorCode ?? 500, response)
sendFailure(ErrorMessage.THRESHOLD_DOMAIN_QUOTA_STATUS_FAILURE, maxErrorCode ?? 500, response)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import {
DomainSchema,
ErrorMessage,
ErrorType,
getSignerEndpoint,
OdisResponse,
send,
SequentialDelayDomainStateSchema,
Expand All @@ -15,26 +16,24 @@ import {
import assert from 'node:assert'
import { Signer, thresholdCallToSigners } from '../../../common/combine'
import { DomainCryptoClient } from '../../../common/crypto-clients/domain-crypto-client'

import { PromiseHandler } from '../../../common/handlers'
import { getKeyVersionInfo, requestHasSupportedKeyVersion, sendFailure } from '../../../common/io'

import { getCombinerVersion, OdisConfig } from '../../../config'
import { logDomainResponsesDiscrepancies } from '../../services/log-responses'
import { logDomainResponseDiscrepancies } from '../../services/log-responses'
import { findThresholdDomainState } from '../../services/threshold-state'

export function createDomainSignHandler(
signers: Signer[],
config: OdisConfig
): PromiseHandler<DomainRestrictedSignatureRequest> {
const requestSchema = domainRestrictedSignatureRequestSchema(DomainSchema)
const signerEndpoint = CombinerEndpoint.DOMAIN_SIGN
return async (request, response) => {
if (!requestSchema.is(request.body)) {
const { logger } = response.locals

if (!domainRestrictedSignatureRequestSchema(DomainSchema).is(request.body)) {
sendFailure(WarningMessage.INVALID_INPUT, 400, response)
return
}
if (!requestHasSupportedKeyVersion(request, config, response.locals.logger)) {
if (!requestHasSupportedKeyVersion(request, config, logger)) {
sendFailure(WarningMessage.INVALID_KEY_VERSION_REQUEST, 400, response)
return
}
Expand All @@ -47,11 +46,10 @@ export function createDomainSignHandler(
return
}

const keyVersionInfo = getKeyVersionInfo(request, config, response.locals.logger)
const keyVersionInfo = getKeyVersionInfo(request, config, logger)
const crypto = new DomainCryptoClient(keyVersionInfo)

const logger = response.locals.logger
const processRequest = async (
const processResult = async (
res: OdisResponse<DomainRestrictedSignatureRequest>
): Promise<boolean> => {
assert(res.success)
Expand Down Expand Up @@ -86,17 +84,19 @@ export function createDomainSignHandler(

const { signerResponses, maxErrorCode } = await thresholdCallToSigners(
response.locals.logger,
signers,
signerEndpoint,
request,
keyVersionInfo,
config.odisServices.timeoutMilliSeconds,
domainRestrictedSignatureResponseSchema(SequentialDelayDomainStateSchema),
true,
processRequest
{
signers,
endpoint: getSignerEndpoint(CombinerEndpoint.DOMAIN_SIGN),
request,
keyVersionInfo,
requestTimeoutMS: config.odisServices.timeoutMilliSeconds,
responseSchema: domainRestrictedSignatureResponseSchema(SequentialDelayDomainStateSchema),
shouldCheckKeyVersion: true,
},
processResult
)

logDomainResponsesDiscrepancies(response.locals.logger, signerResponses)
logDomainResponseDiscrepancies(response.locals.logger, signerResponses)

if (crypto.hasSufficientSignatures()) {
try {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,10 @@
import {
DomainRequest,
DomainRestrictedSignatureRequest,
WarningMessage,
} from '@celo/phone-number-privacy-common'
import { DomainRequest, WarningMessage } from '@celo/phone-number-privacy-common'
import Logger from 'bunyan'
import { SignerResponse } from '../../common/io'

export function logDomainResponsesDiscrepancies(
export function logDomainResponseDiscrepancies<R extends DomainRequest>(
logger: Logger,
responses: Array<SignerResponse<DomainRequest | DomainRestrictedSignatureRequest>>
responses: Array<SignerResponse<R>>
) {
const parsedResponses: Array<{
signerUrl: string
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import {
CombinerEndpoint,
DataEncryptionKeyFetcher,
ErrorMessage,
getSignerEndpoint,
hasValidAccountParam,
isBodyReasonablySized,
PnpQuotaRequest,
Expand All @@ -15,23 +16,18 @@ import { Request } from 'express'
import { Signer, thresholdCallToSigners } from '../../../common/combine'
import { PromiseHandler } from '../../../common/handlers'
import { getKeyVersionInfo, sendFailure } from '../../../common/io'

import { getCombinerVersion, OdisConfig } from '../../../config'
import {
logFailOpenResponses,
logPnpSignerResponseDiscrepancies,
} from '../../services/log-responses'
import { logPnpSignerResponseDiscrepancies } from '../../services/log-responses'
import { findCombinerQuotaState } from '../../services/threshold-state'

export function createPnpQuotaHandler(
signers: Signer[],
config: OdisConfig,
dekFetcher: DataEncryptionKeyFetcher
): PromiseHandler<PnpQuotaRequest> {
const signerEndpoint = CombinerEndpoint.PNP_QUOTA

return async (request, response) => {
const logger = response.locals.logger

if (!validateRequest(request)) {
sendFailure(WarningMessage.INVALID_INPUT, 400, response)
return
Expand All @@ -41,19 +37,23 @@ export function createPnpQuotaHandler(
sendFailure(WarningMessage.UNAUTHENTICATED_USER, 401, response)
return
}

// TODO remove?
const keyVersionInfo = getKeyVersionInfo(request, config, logger)

const { signerResponses, maxErrorCode } = await thresholdCallToSigners(
logger,
const { signerResponses, maxErrorCode } = await thresholdCallToSigners(logger, {
signers,
signerEndpoint,
endpoint: getSignerEndpoint(CombinerEndpoint.PNP_QUOTA),
request,
keyVersionInfo,
config.odisServices.timeoutMilliSeconds,
PnpQuotaResponseSchema
)
requestTimeoutMS: config.odisServices.timeoutMilliSeconds,
responseSchema: PnpQuotaResponseSchema,
shouldCheckKeyVersion: false,
})
const warnings = logPnpSignerResponseDiscrepancies(logger, signerResponses)
logFailOpenResponses(logger, signerResponses)

// TODO remove?
// logFailOpenResponses(logger, signerResponses)

const { threshold } = keyVersionInfo

Expand Down
Loading

0 comments on commit f47f7c6

Please sign in to comment.