Skip to content

Commit

Permalink
Refactor combiner (#10533)
Browse files Browse the repository at this point in the history
* Refactor combiner

* Remove combiner timeouts

* fix lint

---------

Co-authored-by: Alec Schaefer <alec@cLabs.co>
  • Loading branch information
gastonponti and alecps committed Sep 7, 2023
1 parent 619ae9d commit 66b1fb3
Show file tree
Hide file tree
Showing 10 changed files with 254 additions and 146 deletions.
19 changes: 19 additions & 0 deletions packages/phone-number-privacy/combiner/src/common/error.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import { ErrorType } from '@celo/phone-number-privacy-common'

export class OdisError extends Error {
constructor(readonly code: ErrorType, readonly parent?: Error, readonly status: number = 500) {
// This is necessary when extending Error Classes
super(code) // 'Error' breaks prototype chain here
Object.setPrototypeOf(this, new.target.prototype) // restore prototype chain
}
}

export function wrapError<T>(
valueOrError: Promise<T>,
code: ErrorType,
status: number = 500
): Promise<T> {
return valueOrError.catch((parentErr) => {
throw new OdisError(code, parentErr, status)
})
}
141 changes: 129 additions & 12 deletions packages/phone-number-privacy/combiner/src/common/handlers.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,23 @@
import {
ErrorMessage,
ErrorType,
OdisRequest,
OdisResponse,
PnpQuotaStatus,
send,
// tslint:disable-next-line: ordered-imports
SequentialDelayDomainState,
WarningMessage,
} from '@celo/phone-number-privacy-common'
import opentelemetry, { SpanStatusCode } from '@opentelemetry/api'
import { SemanticAttributes } from '@opentelemetry/semantic-conventions'
import Logger from 'bunyan'
import { Request, Response } from 'express'
import { performance, PerformanceObserver } from 'perf_hooks'
import { sendFailure } from './io'
import { getCombinerVersion } from '../config'
import { OdisError } from './error'

const tracer = opentelemetry.trace.getTracer('combiner-tracer')

export interface Locals {
logger: Logger
Expand All @@ -18,31 +28,62 @@ export type PromiseHandler<R extends OdisRequest> = (
res: Response<OdisResponse<R>, Locals>
) => Promise<void>

type ParentHandler = (req: Request<{}, {}, any>, res: Response<any, Locals>) => Promise<void>

export function catchErrorHandler<R extends OdisRequest>(
handler: PromiseHandler<R>
): ParentHandler {
): PromiseHandler<R> {
return async (req, res) => {
const logger: Logger = res.locals.logger
try {
await handler(req, res)
} catch (err) {
const logger: Logger = res.locals.logger
logger.error(ErrorMessage.CAUGHT_ERROR_IN_ENDPOINT_HANDLER)
logger.error(err)
if (!res.headersSent) {
logger.info('Responding with error in outer endpoint handler')
res.status(500).json({
success: false,
error: ErrorMessage.UNKNOWN_ERROR,
})
if (err instanceof OdisError) {
sendFailure(err.code, err.status, res, req.url)
} else {
sendFailure(ErrorMessage.UNKNOWN_ERROR, 500, res, req.url)
}
} else {
logger.error(ErrorMessage.ERROR_AFTER_RESPONSE_SENT)
}
}
}
}

export function tracingHandler<R extends OdisRequest>(
handler: PromiseHandler<R>
): PromiseHandler<R> {
return async (req, res) => {
return tracer.startActiveSpan(
req.url,
{
attributes: {
[SemanticAttributes.HTTP_ROUTE]: req.path,
[SemanticAttributes.HTTP_METHOD]: req.method,
[SemanticAttributes.HTTP_CLIENT_IP]: req.ip,
},
},
async (span) => {
try {
await handler(req, res)
span.setStatus({
code: SpanStatusCode.OK,
})
} catch (err: any) {
span.setStatus({
code: SpanStatusCode.ERROR,
message: err instanceof Error ? err.message : 'Fail',
})
throw err
} finally {
span.end()
}
}
)
}
}

export function meteringHandler<R extends OdisRequest>(
handler: PromiseHandler<R>
): PromiseHandler<R> {
Expand Down Expand Up @@ -86,9 +127,85 @@ export function meteringHandler<R extends OdisRequest>(
}
}

export function timeoutHandler<R extends OdisRequest>(
timeoutMs: number,
handler: PromiseHandler<R>
): PromiseHandler<R> {
return async (req, res) => {
const timeoutSignal = (AbortSignal as any).timeout(timeoutMs)
timeoutSignal.addEventListener(
'abort',
() => {
if (!res.headersSent) {
sendFailure(ErrorMessage.TIMEOUT_FROM_SIGNER, 500, res, req.url)
}
},
{ once: true }
)

await handler(req, res)
}
}

export async function disabledHandler<R extends OdisRequest>(
_: Request<{}, {}, R>,
req: Request<{}, {}, R>,
response: Response<OdisResponse<R>, Locals>
): Promise<void> {
sendFailure(WarningMessage.API_UNAVAILABLE, 503, response)
sendFailure(WarningMessage.API_UNAVAILABLE, 503, response, req.url)
}

export function sendFailure(
error: ErrorType,
status: number,
response: Response,
_endpoint: string,
body?: Record<any, any> // TODO remove any
) {
send(
response,
{
success: false,
version: getCombinerVersion(),
error,
...body,
},
status,
response.locals.logger
)
}

export interface Result<R extends OdisRequest> {
status: number
body: OdisResponse<R>
}

export type ResultHandler<R extends OdisRequest> = (
request: Request<{}, {}, R>,
res: Response<OdisResponse<R>, Locals>
) => Promise<Result<R>>

export function resultHandler<R extends OdisRequest>(
resHandler: ResultHandler<R>
): PromiseHandler<R> {
return async (req, res) => {
const result = await resHandler(req, res)
send(res, result.body, result.status, res.locals.logger)
}
}

export function errorResult(
status: number,
error: string,
quotaStatus?: PnpQuotaStatus | { status: SequentialDelayDomainState }
): Result<any> {
// TODO remove any
return {
status,
body: {
success: false,
version: getCombinerVersion(),
error,
...quotaStatus,
},
}
}
19 changes: 2 additions & 17 deletions packages/phone-number-privacy/combiner/src/common/io.ts
Original file line number Diff line number Diff line change
@@ -1,21 +1,19 @@
import {
ErrorType,
getRequestKeyVersion,
KEY_VERSION_HEADER,
KeyVersionInfo,
OdisRequest,
OdisResponse,
requestHasValidKeyVersion,
send,
SignerEndpoint,
} from '@celo/phone-number-privacy-common'
import Logger from 'bunyan'
import { Request, Response } from 'express'
import { Request } from 'express'
import * as http from 'http'
import * as https from 'https'
import fetch, { Response as FetchResponse } from 'node-fetch'
import { performance } from 'perf_hooks'
import { getCombinerVersion, OdisConfig } from '../config'
import { OdisConfig } from '../config'
import { isAbortError, Signer } from './combine'

const httpAgent = new http.Agent({ keepAlive: true })
Expand Down Expand Up @@ -116,16 +114,3 @@ async function measureTime<T>(name: string, fn: () => Promise<T>): Promise<T> {
performance.measure(name, start, end)
}
}

export function sendFailure(error: ErrorType, status: number, response: Response<any>) {
send(
response,
{
success: false,
version: getCombinerVersion(),
error,
},
status,
response.locals.logger
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,31 +6,28 @@ import {
DomainSchema,
ErrorMessage,
getSignerEndpoint,
send,
SequentialDelayDomainStateSchema,
verifyDisableDomainRequestAuthenticity,
WarningMessage,
} from '@celo/phone-number-privacy-common'
import { Signer, thresholdCallToSigners } from '../../../common/combine'
import { PromiseHandler } from '../../../common/handlers'
import { getKeyVersionInfo, sendFailure } from '../../../common/io'
import { errorResult, ResultHandler } from '../../../common/handlers'
import { getKeyVersionInfo } from '../../../common/io'
import { getCombinerVersion, OdisConfig } from '../../../config'
import { logDomainResponseDiscrepancies } from '../../services/log-responses'
import { findThresholdDomainState } from '../../services/threshold-state'

export function createDisableDomainHandler(
export function disableDomain(
signers: Signer[],
config: OdisConfig
): PromiseHandler<DisableDomainRequest> {
): ResultHandler<DisableDomainRequest> {
return async (request, response) => {
if (!disableDomainRequestSchema(DomainSchema).is(request.body)) {
sendFailure(WarningMessage.INVALID_INPUT, 400, response)
return
return errorResult(400, WarningMessage.INVALID_INPUT)
}

if (!verifyDisableDomainRequestAuthenticity(request.body)) {
sendFailure(WarningMessage.UNAUTHENTICATED_USER, 401, response)
return
return errorResult(401, WarningMessage.UNAUTHENTICATED_USER)
}

// TODO remove?
Expand All @@ -57,18 +54,14 @@ export function createDisableDomainHandler(
signers.length
)
if (disableDomainStatus.disabled) {
send(
response,
{
return {
status: 200,
body: {
success: true,
version: getCombinerVersion(),
status: disableDomainStatus,
},
200,
response.locals.logger
)

return
}
}
} catch (err) {
response.locals.logger.error(
Expand All @@ -77,6 +70,6 @@ export function createDisableDomainHandler(
)
}

sendFailure(ErrorMessage.THRESHOLD_DISABLE_DOMAIN_FAILURE, maxErrorCode ?? 500, response)
return errorResult(maxErrorCode ?? 500, ErrorMessage.THRESHOLD_DISABLE_DOMAIN_FAILURE)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,31 +6,28 @@ import {
DomainSchema,
ErrorMessage,
getSignerEndpoint,
send,
SequentialDelayDomainStateSchema,
verifyDomainQuotaStatusRequestAuthenticity,
WarningMessage,
} from '@celo/phone-number-privacy-common'
import { Signer, thresholdCallToSigners } from '../../../common/combine'
import { PromiseHandler } from '../../../common/handlers'
import { getKeyVersionInfo, sendFailure } from '../../../common/io'
import { errorResult, ResultHandler } from '../../../common/handlers'
import { getKeyVersionInfo } from '../../../common/io'
import { getCombinerVersion, OdisConfig } from '../../../config'
import { logDomainResponseDiscrepancies } from '../../services/log-responses'
import { findThresholdDomainState } from '../../services/threshold-state'

export function createDomainQuotaHandler(
export function domainQuota(
signers: Signer[],
config: OdisConfig
): PromiseHandler<DomainQuotaStatusRequest> {
): ResultHandler<DomainQuotaStatusRequest> {
return async (request, response) => {
if (!domainQuotaStatusRequestSchema(DomainSchema).is(request.body)) {
sendFailure(WarningMessage.INVALID_INPUT, 400, response)
return
return errorResult(400, WarningMessage.INVALID_INPUT)
}

if (!verifyDomainQuotaStatusRequestAuthenticity(request.body)) {
sendFailure(WarningMessage.UNAUTHENTICATED_USER, 401, response)
return
return errorResult(401, WarningMessage.UNAUTHENTICATED_USER)
}

// TODO remove?
Expand All @@ -49,21 +46,18 @@ export function createDomainQuotaHandler(
logDomainResponseDiscrepancies(response.locals.logger, signerResponses)
if (signerResponses.length >= keyVersionInfo.threshold) {
try {
send(
response,
{
return {
status: 200,
body: {
success: true,
version: getCombinerVersion(),
status: findThresholdDomainState(keyVersionInfo, signerResponses, signers.length),
},
200,
response.locals.logger
)
return
}
} catch (err) {
response.locals.logger.error(err, 'Error combining signer quota status responses')
}
}
sendFailure(ErrorMessage.THRESHOLD_DOMAIN_QUOTA_STATUS_FAILURE, maxErrorCode ?? 500, response)
return errorResult(maxErrorCode ?? 500, ErrorMessage.THRESHOLD_DOMAIN_QUOTA_STATUS_FAILURE)
}
}
Loading

0 comments on commit 66b1fb3

Please sign in to comment.