Skip to content

Commit

Permalink
Typescript high-level api for Sets (#7471)
Browse files Browse the repository at this point in the history
  • Loading branch information
YuantianDing authored Dec 5, 2024
1 parent a17d4e6 commit 4be4067
Show file tree
Hide file tree
Showing 3 changed files with 330 additions and 0 deletions.
127 changes: 127 additions & 0 deletions src/api/js/src/high-level/high-level.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,133 @@ describe('high-level', () => {
});
});

describe('sets', () => {
it('Example 1', async () => {
const Z3 = api.Context('main');

const set = Z3.Set.const('set', Z3.Int.sort());
const [a, b] = Z3.Int.consts('a b');

const conjecture = set.contains(a).and(set.contains(b)).implies(Z3.EmptySet(Z3.Int.sort()).neq(set));
await prove(conjecture);
});

it('Example 2', async () => {
const Z3 = api.Context('main');

const set = Z3.Set.const('set', Z3.Int.sort());
const [a, b] = Z3.Int.consts('a b');

const conjecture = set.contains(a).and(set.contains(b)).implies(Z3.Set.val([a, b], Z3.Int.sort()).subsetOf(set));
await prove(conjecture);
});

it('Example 3', async () => {
const Z3 = api.Context('main');

const set = Z3.Set.const('set', Z3.Int.sort());
const [a, b] = Z3.Int.consts('a b');

const conjecture = set.contains(a).and(set.contains(b)).and(Z3.Set.val([a, b], Z3.Int.sort()).eq(set));
await solve(conjecture);
});

it('Intersection 1', async () => {
const Z3 = api.Context('main');

const set = Z3.Set.const('set', Z3.Int.sort());
const [a, b] = Z3.Int.consts('a b');
const abset = Z3.Set.val([a, b], Z3.Int.sort());

const conjecture = set.intersect(abset).subsetOf(abset);
await prove(conjecture);
});

it('Intersection 2', async () => {
const Z3 = api.Context('main');

const set = Z3.Set.const('set', Z3.Int.sort());
const [a, b] = Z3.Int.consts('a b');
const abset = Z3.Set.val([a, b], Z3.Int.sort());

const conjecture = set.subsetOf(set.intersect(abset));
await solve(conjecture);
});

it('Union 1', async () => {
const Z3 = api.Context('main');

const set = Z3.Set.const('set', Z3.Int.sort());
const [a, b] = Z3.Int.consts('a b');
const abset = Z3.Set.val([a, b], Z3.Int.sort());

const conjecture = set.subsetOf(set.union(abset));
await prove(conjecture);
});

it('Union 2', async () => {
const Z3 = api.Context('main');

const set = Z3.Set.const('set', Z3.Int.sort());
const [a, b] = Z3.Int.consts('a b');
const abset = Z3.Set.val([a, b], Z3.Int.sort());

const conjecture = set.union(abset).subsetOf(abset);
await solve(conjecture);
});

it('Complement 1', async () => {
const Z3 = api.Context('main');

const set = Z3.Set.const('set', Z3.Int.sort());
const a = Z3.Int.const('a');

const conjecture = set.complement().complement().eq(set)
await prove(conjecture);
});
it('Complement 2', async () => {
const Z3 = api.Context('main');

const set = Z3.Set.const('set', Z3.Int.sort());
const a = Z3.Int.const('a');

const conjecture = set.contains(a).implies(Z3.Not(set.complement().contains(a)))
await prove(conjecture);
});

it('Difference', async () => {
const Z3 = api.Context('main');

const [set1, set2] = Z3.Set.consts('set1 set2', Z3.Int.sort());
const a = Z3.Int.const('a');

const conjecture = set1.contains(a).implies(Z3.Not(set2.diff(set1).contains(a)))

await prove(conjecture);
});

it('FullSet', async () => {
const Z3 = api.Context('main');

const set = Z3.Set.const('set', Z3.Int.sort());

const conjecture = set.complement().eq(Z3.FullSet(Z3.Int.sort()).diff(set));

await prove(conjecture);
});

it('SetDel', async () => {
const Z3 = api.Context('main');

const empty = Z3.Set.empty(Z3.Int.sort());
const [a, b] = Z3.Int.consts('a b');

const conjecture = empty.add(a).add(b).del(a).del(b).eq(empty);

await prove(conjecture);
});
});

describe('quantifiers', () => {
it('Basic Universal', async () => {
const Z3 = api.Context('main');
Expand Down
120 changes: 120 additions & 0 deletions src/api/js/src/high-level/high-level.ts
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ import {
CoercibleToArith,
NonEmptySortArray,
FuncEntry,
SMTSetSort,
SMTSet,
} from './types';
import { allSatisfy, assert, assertExhaustive } from './utils';

Expand Down Expand Up @@ -795,6 +797,33 @@ export function createApi(Z3: Z3Core): Z3HighLevel {
return new ArrayImpl<[DomainSort], RangeSort>(check(Z3.mk_const_array(contextPtr, domain.ptr, value.ptr)));
},
};
const Set = {
// reference: https://z3prover.github.io/api/html/namespacez3py.html#a545f894afeb24caa1b88b7f2a324ee7e
sort<ElemSort extends AnySort<Name>>(sort: ElemSort): SMTSetSort<Name, ElemSort> {
return Array.sort(sort, Bool.sort());
},
const<ElemSort extends AnySort<Name>>(name: string, sort: ElemSort) : SMTSet<Name, ElemSort> {
return new SetImpl<ElemSort>(
check(Z3.mk_const(contextPtr, _toSymbol(name), Array.sort(sort, Bool.sort()).ptr)),
);
},
consts<ElemSort extends AnySort<Name>>(names: string | string[], sort: ElemSort) : SMTSet<Name, ElemSort>[] {
if (typeof names === 'string') {
names = names.split(' ');
}
return names.map(name => Set.const(name, sort));
},
empty<ElemSort extends AnySort<Name>>(sort: ElemSort): SMTSet<Name, ElemSort> {
return EmptySet(sort);
},
val<ElemSort extends AnySort<Name>>(values: CoercibleToMap<SortToExprMap<ElemSort, Name>, Name>[], sort: ElemSort): SMTSet<Name, ElemSort> {
var result = EmptySet(sort);
for (const value of values) {
result = SetAdd(result, value);
}
return result;
}
}

////////////////
// Operations //
Expand Down Expand Up @@ -1249,6 +1278,49 @@ export function createApi(Z3: Z3Core): Z3HighLevel {
>;
}

function SetUnion<ElemSort extends AnySort<Name>>(...args: SMTSet<Name, ElemSort>[]): SMTSet<Name, ElemSort> {
return new SetImpl<ElemSort>(check(Z3.mk_set_union(contextPtr, args.map(arg => arg.ast))));
}

function SetIntersect<ElemSort extends AnySort<Name>>(...args: SMTSet<Name, ElemSort>[]): SMTSet<Name, ElemSort> {
return new SetImpl<ElemSort>(check(Z3.mk_set_intersect(contextPtr, args.map(arg => arg.ast))));
}

function SetDifference<ElemSort extends AnySort<Name>>(a: SMTSet<Name, ElemSort>, b: SMTSet<Name, ElemSort>): SMTSet<Name, ElemSort> {
return new SetImpl<ElemSort>(check(Z3.mk_set_difference(contextPtr, a.ast, b.ast)));
}

function SetHasSize<ElemSort extends AnySort<Name>>(set: SMTSet<Name, ElemSort>, size: bigint | number | string | IntNum<Name>): Bool<Name> {
const a = typeof size === 'object'? Int.sort().cast(size) : Int.sort().cast(size);
return new BoolImpl(check(Z3.mk_set_has_size(contextPtr, set.ast, a.ast)));
}

function SetAdd<ElemSort extends AnySort<Name>>(set: SMTSet<Name, ElemSort>, elem: CoercibleToMap<SortToExprMap<ElemSort, Name>, Name>): SMTSet<Name, ElemSort> {
const arg = set.elemSort().cast(elem as any);
return new SetImpl<ElemSort>(check(Z3.mk_set_add(contextPtr, set.ast, arg.ast)));
}
function SetDel<ElemSort extends AnySort<Name>>(set: SMTSet<Name, ElemSort>, elem: CoercibleToMap<SortToExprMap<ElemSort, Name>, Name>): SMTSet<Name, ElemSort> {
const arg = set.elemSort().cast(elem as any);
return new SetImpl<ElemSort>(check(Z3.mk_set_del(contextPtr, set.ast, arg.ast)));
}
function SetComplement<ElemSort extends AnySort<Name>>(set: SMTSet<Name, ElemSort>): SMTSet<Name, ElemSort> {
return new SetImpl<ElemSort>(check(Z3.mk_set_complement(contextPtr, set.ast)));
}

function EmptySet<ElemSort extends AnySort<Name>>(sort: ElemSort): SMTSet<Name, ElemSort> {
return new SetImpl<ElemSort>(check(Z3.mk_empty_set(contextPtr, sort.ptr)));
}
function FullSet<ElemSort extends AnySort<Name>>(sort: ElemSort): SMTSet<Name, ElemSort> {
return new SetImpl<ElemSort>(check(Z3.mk_full_set(contextPtr, sort.ptr)));
}
function isMember<ElemSort extends AnySort<Name>>(elem: CoercibleToMap<SortToExprMap<ElemSort, Name>, Name>, set: SMTSet<Name, ElemSort>): Bool<Name> {
const arg = set.elemSort().cast(elem as any);
return new BoolImpl(check(Z3.mk_set_member(contextPtr, arg.ast, set.ast)));
}
function isSubset<ElemSort extends AnySort<Name>>(a: SMTSet<Name, ElemSort>, b: SMTSet<Name, ElemSort>): Bool<Name> {
return new BoolImpl(check(Z3.mk_set_subset(contextPtr, a.ast, b.ast)));
}

class AstImpl<Ptr extends Z3_ast> implements Ast<Name, Ptr> {
declare readonly __typename: Ast['__typename'];
readonly ctx: Context<Name>;
Expand Down Expand Up @@ -2536,6 +2608,41 @@ export function createApi(Z3: Z3Core): Z3HighLevel {
}
}

class SetImpl<ElemSort extends Sort<Name>> extends ExprImpl<Z3_ast, ArraySortImpl<[ElemSort], BoolSort<Name>>> implements SMTSet<Name, ElemSort> {
declare readonly __typename: 'Array';

elemSort(): ElemSort {
return this.sort.domain();
}
union(...args: SMTSet<Name, ElemSort>[]): SMTSet<Name, ElemSort> {
return SetUnion(this, ...args);
}
intersect(...args: SMTSet<Name, ElemSort>[]): SMTSet<Name, ElemSort> {
return SetIntersect(this, ...args);
}
diff(b: SMTSet<Name, ElemSort>): SMTSet<Name, ElemSort> {
return SetDifference(this, b);
}
hasSize(size: string | number | bigint | IntNum<Name>): Bool<Name> {
return SetHasSize(this, size);
}
add(elem: CoercibleToMap<SortToExprMap<ElemSort, Name>, Name>): SMTSet<Name, ElemSort> {
return SetAdd(this, elem);
}
del(elem: CoercibleToMap<SortToExprMap<ElemSort, Name>, Name>): SMTSet<Name, ElemSort> {
return SetDel(this, elem);
}
complement(): SMTSet<Name, ElemSort> {
return SetComplement(this);
}
contains(elem: CoercibleToMap<SortToExprMap<ElemSort, Name>, Name>): Bool<Name> {
return isMember(elem, this);
}
subsetOf(b: SMTSet<Name, ElemSort>): Bool<Name> {
return isSubset(this, b);
}
}

class QuantifierImpl<
QVarSorts extends NonEmptySortArray<Name>,
QSort extends BoolSort<Name> | SMTArraySort<Name, QVarSorts>,
Expand Down Expand Up @@ -2917,6 +3024,7 @@ export function createApi(Z3: Z3Core): Z3HighLevel {
Real,
BitVec,
Array,
Set,

////////////////
// Operations //
Expand Down Expand Up @@ -2979,6 +3087,18 @@ export function createApi(Z3: Z3Core): Z3HighLevel {
// Loading //
/////////////
ast_from_string,

SetUnion,
SetIntersect,
SetDifference,
SetHasSize,
SetAdd,
SetDel,
SetComplement,
EmptySet,
FullSet,
isMember,
isSubset,
};
cleanup.register(ctx, () => Z3.del_context(contextPtr));
return ctx;
Expand Down
83 changes: 83 additions & 0 deletions src/api/js/src/high-level/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,8 @@ export interface Context<Name extends string = 'main'> {
readonly BitVec: BitVecCreation<Name>;
/** @category Expressions */
readonly Array: SMTArrayCreation<Name>;
/** @category Expressions */
readonly Set: SMTSetCreation<Name>;

////////////////
// Operations //
Expand Down Expand Up @@ -611,6 +613,39 @@ export interface Context<Name extends string = 'main'> {
substitute(t: Expr<Name>, ...substitutions: [Expr<Name>, Expr<Name>][]): Expr<Name>;

simplify(expr: Expr<Name>): Promise<Expr<Name>>;

/** @category Operations */
SetUnion<ElemSort extends AnySort<Name>>(...args: SMTSet<Name, ElemSort>[]): SMTSet<Name, ElemSort>;

/** @category Operations */
SetIntersect<ElemSort extends AnySort<Name>>(...args: SMTSet<Name, ElemSort>[]): SMTSet<Name, ElemSort>;

/** @category Operations */
SetDifference<ElemSort extends AnySort<Name>>(a: SMTSet<Name, ElemSort>, b: SMTSet<Name, ElemSort>): SMTSet<Name, ElemSort>;

/** @category Operations */
SetHasSize<ElemSort extends AnySort<Name>>(set: SMTSet<Name, ElemSort>, size: bigint | number | string | IntNum<Name>): Bool<Name>;

/** @category Operations */
SetAdd<ElemSort extends AnySort<Name>>(set: SMTSet<Name, ElemSort>, elem: CoercibleToMap<SortToExprMap<ElemSort, Name>, Name>): SMTSet<Name, ElemSort>;

/** @category Operations */
SetDel<ElemSort extends AnySort<Name>>(set: SMTSet<Name, ElemSort>, elem: CoercibleToMap<SortToExprMap<ElemSort, Name>, Name>): SMTSet<Name, ElemSort>;

/** @category Operations */
SetComplement<ElemSort extends AnySort<Name>>(set: SMTSet<Name, ElemSort>): SMTSet<Name, ElemSort>;

/** @category Operations */
EmptySet<ElemSort extends AnySort<Name>>(sort: ElemSort): SMTSet<Name, ElemSort>;

/** @category Operations */
FullSet<ElemSort extends AnySort<Name>>(sort: ElemSort): SMTSet<Name, ElemSort>;

/** @category Operations */
isMember<ElemSort extends AnySort<Name>>(elem: CoercibleToMap<SortToExprMap<ElemSort, Name>, Name>, set: SMTSet<Name, ElemSort>): Bool<Name>;

/** @category Operations */
isSubset<ElemSort extends AnySort<Name>>(a: SMTSet<Name, ElemSort>, b: SMTSet<Name, ElemSort>): Bool<Name>;
}

export interface Ast<Name extends string = 'main', Ptr = unknown> {
Expand Down Expand Up @@ -1568,6 +1603,54 @@ export interface SMTArray<
): SMTArray<Name, DomainSort, RangeSort>;
}

/**
* Set Implemented using Arrays
*
* @typeParam ElemSort The sort of the element of the set
* @category Sets
*/
export type SMTSetSort<Name extends string = 'main', ElemSort extends AnySort<Name> = Sort<Name>> = SMTArraySort<Name, [ElemSort], BoolSort<Name>>;


/** @category Sets*/
export interface SMTSetCreation<Name extends string> {
sort<ElemSort extends AnySort<Name>>(elemSort: ElemSort): SMTSetSort<Name, ElemSort>;

const<ElemSort extends AnySort<Name>>(name: string, elemSort: ElemSort): SMTSet<Name, ElemSort>;

consts<ElemSort extends AnySort<Name>>(names: string | string[], elemSort: ElemSort): SMTSet<Name, ElemSort>[];

empty<ElemSort extends AnySort<Name>>(sort: ElemSort): SMTSet<Name, ElemSort>;

val<ElemSort extends AnySort<Name>>(values: CoercibleToMap<SortToExprMap<ElemSort, Name>, Name>[], sort: ElemSort): SMTSet<Name, ElemSort>;
}

/**
* Represents Set expression
*
* @typeParam ElemSort The sort of the element of the set
* @category Arrays
*/
export interface SMTSet<Name extends string = 'main', ElemSort extends AnySort<Name> = Sort<Name>> extends Expr<Name, SMTSetSort<Name, ElemSort>, Z3_ast> {
readonly __typename: 'Array';

elemSort(): ElemSort;

union(...args: SMTSet<Name, ElemSort>[]): SMTSet<Name, ElemSort>;
intersect(...args: SMTSet<Name, ElemSort>[]): SMTSet<Name, ElemSort>;
diff(b: SMTSet<Name, ElemSort>): SMTSet<Name, ElemSort>;

hasSize(size: bigint | number | string | IntNum<Name>): Bool<Name>;

add(elem: CoercibleToMap<SortToExprMap<ElemSort, Name>, Name>): SMTSet<Name, ElemSort>;
del(elem: CoercibleToMap<SortToExprMap<ElemSort, Name>, Name>): SMTSet<Name, ElemSort>;
complement(): SMTSet<Name, ElemSort>;

contains(elem: CoercibleToMap<SortToExprMap<ElemSort, Name>, Name>): Bool<Name>;
subsetOf(b: SMTSet<Name, ElemSort>): Bool<Name>;

}

/**
* Defines the expression type of the body of a quantifier expression
*
Expand Down

0 comments on commit 4be4067

Please sign in to comment.