Skip to content

Commit

Permalink
feat(providers): flatten multi-providers when using useExisting (#10)
Browse files Browse the repository at this point in the history
  • Loading branch information
dirkluijk authored Sep 21, 2024
1 parent 9b8a0aa commit 48df581
Show file tree
Hide file tree
Showing 6 changed files with 479 additions and 55 deletions.
6 changes: 3 additions & 3 deletions src/container.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@ describe("Container API", () => {
});

it("inject", () => {
expect(() => inject(MyService)).toThrowError();
expect(() => inject(MyService)).toThrowError('You can only invoke inject() from the injection context');

const container = new Container();
const token = new InjectionToken<MyService>("some-token");

expect(() => container.get(token)).toThrowError();
expect(() => container.get(token)).toThrowError('No provider(s) found');

container.bind({
provide: token,
Expand All @@ -34,7 +34,7 @@ describe("Container API", () => {
});

it("injectAsync", async () => {
expect(() => injectAsync(MyService)).toThrowError();
expect(() => injectAsync(MyService)).toThrowError('You can only invoke injectAsync() from the injection context');

const container = new Container();
const token = new InjectionToken<string>("some-token");
Expand Down
97 changes: 70 additions & 27 deletions src/container.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@ export class Container {
private providers: ProviderMap = new Map();
private singletons: SingletonMap = new Map();

bindAll<T>(...providers: Provider<T>[]): this {
providers.forEach((it) => this.bind(it));
return this;
}

bind<T>(provider: Provider<T>): this {
const token = isConstructorProvider(provider) ? provider : provider.provide;
const multi = isMultiProvider(provider);
Expand Down Expand Up @@ -83,12 +88,18 @@ export class Container {

this.singletons.set(
token,
providers.map((it) => construct(it, this)),
providers.flatMap((it) => construct(it, this)),
);
}

if (options?.multi) {
return assertPresent(this.singletons.get(token));
const singletons = assertPresent(this.singletons.get(token));
if (options?.multi === true) {
return singletons;
} else if (singletons.length > 1) {
throw Error(
`Requesting a single value for ${toString(token)}, but multiple values were provided. ` +
`Consider passing "{ multi: true }" to inject all values, or adjust your bindings accordingly.`,
);
} else {
return assertPresent(this.singletons.get(token)?.at(0));
}
Expand Down Expand Up @@ -118,13 +129,19 @@ export class Container {

if (!this.singletons.has(token)) {
const values = await Promise.all(existingProviders.map((it) => constructAsync(it, this)));
this.singletons.set(token, values);
this.singletons.set(token, values.flat());
}

if (options?.multi) {
return Promise.all(assertPresent(this.singletons.get(token)).map((it) => promisify(it)));
const singletons = assertPresent(this.singletons.get(token));
if (options?.multi === true) {
return Promise.all(singletons.map((it) => promisify(it)));
} else if (singletons.length > 1) {
throw Error(
`Requesting a single value for ${toString(token)}, but multiple values were provided. ` +
`Consider passing "{ multi: true }" to inject all values, or adjust your bindings accordingly.`,
);
} else {
return promisify(assertPresent(this.singletons.get(token)?.at(0)));
return promisify(singletons.at(0));
}
}

Expand All @@ -141,21 +158,35 @@ export class Container {
.forEach((targetClass) => {
this.bind({
provide: targetClass,
multi: true,
useClass: targetClass,
multi: true,
});
});

targetClasses
.filter((it) => it !== token)
.forEach((targetClass) => {
this.bind({
provide: token,
multi: true,
useExisting: targetClass,
});
// inheritance support: only register immediate subclasses of the token with useExisting
const immediateSubclasses = [
...new Set(
targetClasses
.filter((targetClass) => targetClass !== token)
// .filter((targetClass) => Object.getPrototypeOf(targetClass) === token)
.map((targetClass) => {
let currentClass = targetClass;
while (Object.getPrototypeOf(currentClass) && Object.getPrototypeOf(currentClass) !== token) {
currentClass = Object.getPrototypeOf(currentClass);
}
return currentClass;
}),
),
];

immediateSubclasses.forEach((immediateSubClass) => {
this.bind({
provide: token,
useExisting: immediateSubClass,
multi: true,
});
} else if (isInjectionToken(token) && token.options?.factory) {
});
} else if (!this.providers.has(token) && isInjectionToken(token) && token.options?.factory) {
if (!token.options.async) {
this.bind({
provide: token,
Expand All @@ -179,6 +210,7 @@ export function inject<T>(token: Token<T>, options: { optional: true }): T | und
export function inject<T>(token: Token<T>): T;
export function inject<T>(token: Token<T>, options?: { optional: boolean }): T | undefined {
if (currentScope === undefined) {
if (options?.optional) return undefined;
throw new Error("You can only invoke inject() from the injection context");
}
return currentScope.get(token, options);
Expand All @@ -188,12 +220,13 @@ export function injectAsync<T>(token: Token<T>, options: { optional: true }): Pr
export function injectAsync<T>(token: Token<T>): Promise<T>;
export function injectAsync<T>(token: Token<T>, options?: { optional: boolean }): Promise<T | undefined> {
if (currentScope === undefined) {
if (options?.optional) return Promise.resolve(undefined);
throw new Error("You can only invoke injectAsync() from the injection context");
}
return currentScope.getAsync(token, options);
}

function construct<T>(provider: Provider<T>, scope: Container): Promise<T> | T {
function construct<T>(provider: Provider<T>, scope: Container): T[] {
const originalScope = currentScope;
try {
currentScope = scope;
Expand All @@ -203,11 +236,11 @@ function construct<T>(provider: Provider<T>, scope: Container): Promise<T> | T {
}
}

async function constructAsync<T>(provider: Provider<T>, scope: Container): Promise<T> {
async function constructAsync<T>(provider: Provider<T>, scope: Container): Promise<T[]> {
const originalScope = currentScope;
try {
currentScope = scope;
return await promisify(doConstruct(provider, scope));
return await doConstructAsync(provider, scope);
} finally {
currentScope = originalScope;
}
Expand All @@ -218,17 +251,27 @@ async function promisify<T>(value: T | Promise<T>): Promise<T> {
return new Promise<T>((resolve) => resolve(value));
}

function doConstruct<T>(provider: Provider<T>, scope: Container): T | Promise<T> {
function doConstruct<T>(provider: Provider<T>, scope: Container): T[] {
if (isConstructorProvider(provider)) {
return new provider();
return [new provider()];
} else if (isClassProvider(provider)) {
return new provider.useClass();
return [new provider.useClass()];
} else if (isValueProvider(provider)) {
return provider.useValue;
} else if (isFactoryProvider(provider)) {
return provider.useFactory();
return [provider.useValue];
} else if (isFactoryProvider(provider) && !provider.async) {
return [provider.useFactory()];
} else if (isFactoryProvider(provider) && provider.async) {
throw Error("Invalid state");
} else {
return scope.get(provider.useExisting);
return scope.get(provider.useExisting, { multi: true });
}
}

async function doConstructAsync<T>(provider: Provider<T>, scope: Container): Promise<T[]> {
if (isFactoryProvider(provider) && provider.async) {
return await provider.useFactory().then((it) => [it]);
} else {
return doConstruct(provider, scope);
}
}

Expand Down
Loading

0 comments on commit 48df581

Please sign in to comment.