-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathMicroKanren.scala
108 lines (86 loc) · 3.15 KB
/
MicroKanren.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
package microKanren
import scala.collection.immutable.{LinearSeq}
/** A small Scala port of microKanren: logic variables, substitutions, goals, and
  * lazily interleaved streams of result states.
  */
object MicroKanren {

  /** A goal maps an input state to a (possibly lazy, possibly infinite) stream of result states. */
  type Goal = State => Stream

  /** Bindings from logic variables to terms (a bound term may itself be a variable). */
  type Subst = Map[Var, Any]

  /** A logic variable, identified by the index allocated to it by [[call_fresh]]. */
  final case class Var(index: Int)

  /** A stream of states produced by running a goal. */
  sealed trait Stream

  /** A stream whose first state `state` is available now, followed by `stream`. */
  final case class MatureStream(state: State, stream: Stream) extends Stream

  /** A suspended stream; calling `func` advances it by one step. */
  final case class ImmatureStream(func: () => Stream) extends Stream

  /** A stream with no states. */
  final case class EmptyStream() extends Stream

  /** A substitution `s` plus the counter `c` used as the index of the next fresh variable. */
  final case class State(s: Subst, c: Int)

  /** Thrown by [[unify]] when two terms cannot be made equal. */
  final case class UnificationException() extends Exception

  /** The initial state: no bindings, and variable indices starting at 0. */
  def EmptyState(): State = State(Map.empty, 0)

  /** Follows bindings in `subst` starting from `u` until reaching a non-variable term or an
    * unbound variable, and returns that term. Does not look inside lists.
    */
  @scala.annotation.tailrec
  def walk(u: Any, subst: Subst): Any = u match {
    case v: Var =>
      subst.get(v) match {
        case Some(bound) => walk(bound, subst)
        case None        => v
      }
    case _ => u
  }

  /** Returns `s` extended with the binding `x -> v` (no occurs check is performed). */
  def extendSubst(x: Var, v: Any, s: Subst): Subst = s + (x -> v)

  /** Extends `s` so that terms `a` and `b` become equal.
    *
    * Both terms are first resolved with [[walk]]. An unbound variable is bound to the other
    * term, two lists are unified element by element, and any other pair must be `==`.
    *
    * @throws UnificationException if the terms cannot be unified (including lists of
    *                              different lengths)
    */
  def unify(a: Any, b: Any, s: Subst): Subst = (walk(a, s), walk(b, s)) match {
    case (x: Var, y: Var) if x == y => s
    case (x: Var, v)                => extendSubst(x, v, s)
    case (v, x: Var)                => extendSubst(x, v, s)
    // Structural recursion: never calls `.head` on an empty list, and avoids an O(n)
    // length check at every level. Lists of different lengths eventually compare `Nil`
    // with a non-empty list, which is not `==` and so fails below.
    case (h1 :: t1, h2 :: t2)       => unify(t1, t2, unify(h1, h2, s))
    case (u, v) if u == v           => s
    case _                          => throw UnificationException()
  }

  /** Allocates a fresh variable (index taken from the state's counter), passes it to `f`, and
    * runs the resulting goal in a state whose counter has been incremented.
    */
  def call_fresh(f: Var => Goal): Goal =
    (state: State) => f(Var(state.c))(state.copy(c = state.c + 1))

  /** Goal that yields exactly one state (with the substitution extended) if `a` and `b` unify,
    * and no states otherwise.
    */
  def :=(a: Any, b: Any): Goal = { (state: State) =>
    try {
      MatureStream(state.copy(s = unify(a, b, state.s)), EmptyStream())
    } catch {
      case UnificationException() => EmptyStream()
    }
  }

  /** Merges two streams. When `s1` is suspended, the operands are swapped before continuing,
    * so states from `s1` and `s2` are interleaved rather than `s2` waiting behind all of `s1`.
    */
  def mplus(s1: Stream, s2: Stream): Stream = s1 match {
    case EmptyStream()             => s2
    case ImmatureStream(func)      => ImmatureStream(() => mplus(s2, func()))
    case MatureStream(state, rest) => MatureStream(state, mplus(rest, s2))
  }

  /** Runs goal `g` on every state of `s1` and merges the resulting streams with [[mplus]]. */
  def bind(s1: Stream, g: Goal): Stream = s1 match {
    case EmptyStream()             => EmptyStream()
    case ImmatureStream(func)      => ImmatureStream(() => bind(func(), g))
    case MatureStream(state, rest) => mplus(g(state), bind(rest, g))
  }

  /** Goal yielding the states of `g1` merged (via [[mplus]]) with the states of `g2`. */
  def disj(g1: Goal, g2: Goal): Goal = (state: State) => mplus(g1(state), g2(state))

  /** Goal that runs `g2` on every state produced by `g1` (via [[bind]]). */
  def conj(g1: Goal, g2: Goal): Goal = (state: State) => bind(g1(state), g2)

  /** Wraps a goal so that running it returns a suspended stream.
    *
    * `g` is a by-name parameter: it is evaluated only when the suspension is forced (and again
    * on each force). This lets recursive goals such as
    * `def fives(x: Var): Goal = disj(:=(x, 5), zzz(fives(x)))` be built without infinite
    * recursion at construction time.
    */
  def zzz(g: => Goal): Goal = (s: State) => ImmatureStream(() => g(s))

  /** Identity for [[conj]]: yields the input state unchanged, once. */
  private val succeedGoal: Goal = (state: State) => MatureStream(state, EmptyStream())

  /** Identity for [[disj]]: yields no states. */
  private val failGoal: Goal = (_: State) => EmptyStream()

  /** Conjunction of any number of goals, each wrapped in [[zzz]] and nested to the right, i.e.
    * `conjMult(a, b, c)` behaves as `conj(zzz(a), conj(zzz(b), zzz(c)))`. With no goals it
    * succeeds once with the unchanged state.
    */
  def conjMult(goals: Goal*): Goal =
    // Varargs arrive as an arbitrary Seq (ArraySeq / WrappedArray, not List), so fold over it
    // instead of matching on `::`, which threw MatchError for ordinary calls.
    goals.map(g => zzz(g)).reduceRightOption[Goal](conj(_, _)).getOrElse(succeedGoal)

  /** Disjunction of any number of goals, each wrapped in [[zzz]] and nested to the right, i.e.
    * `disjMult(a, b, c)` behaves as `disj(zzz(a), disj(zzz(b), zzz(c)))`. With no goals it
    * yields no states.
    */
  def disjMult(goals: Goal*): Goal =
    goals.map(g => zzz(g)).reduceRightOption[Goal](disj(_, _)).getOrElse(failGoal)

  /** Forces suspended streams until the stream is empty or has a state available. */
  @scala.annotation.tailrec
  def pull(stream: Stream): Stream = stream match {
    case ImmatureStream(func) => pull(func())
    case _                    => stream
  }

  /** Collects every state of `stream`, in order. Does not terminate on an infinite stream. */
  def take_all(stream: Stream): List[State] = {
    // Accumulate in reverse so long result streams do not overflow the call stack.
    @scala.annotation.tailrec
    def loop(current: Stream, acc: List[State]): List[State] = pull(current) match {
      case MatureStream(state, rest) => loop(rest, state :: acc)
      case _                         => acc.reverse // `pull` only returns empty or mature streams
    }
    loop(stream, Nil)
  }

  /** Collects at most `n` states from `stream`, in order. The stream is not forced at all when
    * `n <= 0`, nor any further once `n` states have been collected.
    */
  def take_n(n: Int, stream: Stream): List[State] = {
    @scala.annotation.tailrec
    def loop(remaining: Int, current: Stream, acc: List[State]): List[State] =
      if (remaining <= 0) acc.reverse
      else
        pull(current) match {
          case MatureStream(state, rest) => loop(remaining - 1, rest, state :: acc)
          case _                         => acc.reverse
        }
    loop(n, stream, Nil)
  }
}