diff --git a/x/logic/interpreter/registry.go b/x/logic/interpreter/registry.go index d06eada8..63833c2d 100644 --- a/x/logic/interpreter/registry.go +++ b/x/logic/interpreter/registry.go @@ -55,7 +55,7 @@ var registry = map[string]any{ "current_output/1": engine.CurrentOutput, "set_input/1": engine.SetInput, "set_output/1": engine.SetOutput, - "open/4": engine.Open, + "open/4": predicate.Open, "close/2": engine.Close, "flush_output/1": engine.FlushOutput, "stream_property/2": engine.StreamProperty, diff --git a/x/logic/predicate/file.go b/x/logic/predicate/file.go index 11643e1e..a6222264 100644 --- a/x/logic/predicate/file.go +++ b/x/logic/predicate/file.go @@ -3,6 +3,7 @@ package predicate import ( "context" "fmt" + "os" "reflect" "sort" @@ -57,6 +58,84 @@ func SourceFile(vm *engine.VM, file engine.Term, cont engine.Cont, env *engine.E return engine.Delay(promises...) } +// ioMode describes what operations you can perform on the stream. +type ioMode int + +const ( + // ioModeRead means you can read from the stream. + ioModeRead = ioMode(os.O_RDONLY) + // ioModeWrite means you can write to the stream. + ioModeWrite = ioMode(os.O_CREATE | os.O_WRONLY) + // ioModeAppend means you can append to the stream. + ioModeAppend = ioMode(os.O_APPEND) | ioModeWrite +) + +var ( + atomRead = engine.NewAtom("read") + atomWrite = engine.NewAtom("write") + atomAppend = engine.NewAtom("append") +) + +func (m ioMode) Term() engine.Term { + return [...]engine.Term{ + ioModeRead: atomRead, + ioModeWrite: atomWrite, + ioModeAppend: atomAppend, + }[m] +} + +// Open opens SourceSink in mode and unifies with stream. +func Open(vm *engine.VM, sourceSink, mode, stream, options engine.Term, k engine.Cont, env *engine.Env) *engine.Promise { + var name string + switch s := env.Resolve(sourceSink).(type) { + case engine.Variable: + return engine.Error(fmt.Errorf("open/4: source cannot be a variable")) + case engine.Atom: + name = s.String() + default: + return engine.Error(fmt.Errorf("open/4: invalid domain for source, should be an atom, got %T", s)) + } + + var streamMode ioMode + switch m := env.Resolve(mode).(type) { + case engine.Variable: + return engine.Error(fmt.Errorf("open/4: streamMode cannot be a variable")) + case engine.Atom: + var ok bool + streamMode, ok = map[engine.Atom]ioMode{ + atomRead: ioModeRead, + atomWrite: ioModeWrite, + atomAppend: ioModeAppend, + }[m] + if !ok { + return engine.Error(fmt.Errorf("open/4: invalid open mode (read | write | append)")) + } + default: + return engine.Error(fmt.Errorf("open/4: invalid domain for open mode, should be an atom, got %T", m)) + } + + if _, ok := env.Resolve(stream).(engine.Variable); !ok { + return engine.Error(fmt.Errorf("open/4: stream can only be a variable, got %T", env.Resolve(stream))) + } + + if streamMode != ioModeRead { + return engine.Error(fmt.Errorf("open/4: only read mode is allowed here")) + } + + f, err := vm.FS.Open(name) + if err != nil { + return engine.Error(fmt.Errorf("open/4: couldn't open stream: %w", err)) + } + s := engine.NewInputTextStream(f) + + iter := engine.ListIterator{List: options, Env: env} + for iter.Next() { + return engine.Error(fmt.Errorf("open/4: options is not allowed here")) + } + + return engine.Unify(vm, stream, s, k, env) +} + func getLoadedSources(vm *engine.VM) map[string]struct{} { loadedField := reflect.ValueOf(vm).Elem().FieldByName("loaded").MapKeys() loaded := make(map[string]struct{}, len(loadedField)) diff --git a/x/logic/predicate/file_test.go b/x/logic/predicate/file_test.go index b78ee71c..79198399 100644 --- a/x/logic/predicate/file_test.go +++ b/x/logic/predicate/file_test.go @@ -4,6 +4,7 @@ package predicate import ( goctx "context" "fmt" + fs2 "io/fs" "net/url" "testing" "time" @@ -185,3 +186,195 @@ func TestSourceFile(t *testing.T) { } }) } + +func TestOpen(t *testing.T) { + Convey("Given a test cases", t, func() { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + cases := []struct { + files map[string][]byte + program string + query string + wantResult []types.TermResults + wantError error + wantSuccess bool + }{ + { + files: map[string][]byte{ + "file": []byte("dumb(dumber)."), + }, + program: "get_first_char(C) :- open(file, read, Stream, _), get_char(Stream, C).", + query: `get_first_char(C).`, + wantResult: []types.TermResults{{ + "C": "d", + }}, + wantSuccess: true, + }, + { + files: map[string][]byte{ + "file": []byte("Hey"), + }, + program: "get_first_char(C) :- open(file, read, Stream, []), get_char(Stream, C).", + query: `get_first_char(C).`, + wantResult: []types.TermResults{{ + "C": "'H'", + }}, + wantSuccess: true, + }, + { + files: map[string][]byte{ + "file": []byte("dumb(dumber)."), + }, + program: "get_first_char(C) :- open(File, write, Stream, _), get_char(Stream, C).", + query: `get_first_char(C).`, + wantError: fmt.Errorf("open/4: source cannot be a variable"), + wantSuccess: false, + }, + { + files: map[string][]byte{ + "file": []byte("dumb(dumber)."), + }, + program: "get_first_char(C) :- open(34, write, Stream, _), get_char(Stream, C).", + query: `get_first_char(C).`, + wantError: fmt.Errorf("open/4: invalid domain for source, should be an atom, got engine.Integer"), + wantSuccess: false, + }, + { + files: map[string][]byte{ + "file": []byte("dumb(dumber)."), + }, + program: "get_first_char(C) :- open(file, write, stream, _), get_char(Stream, C).", + query: `get_first_char(C).`, + wantError: fmt.Errorf("open/4: stream can only be a variable, got engine.Atom"), + wantSuccess: false, + }, + { + files: map[string][]byte{ + "file": []byte("dumb(dumber)."), + }, + program: "get_first_char(C) :- open(file, 45, Stream, _), get_char(Stream, C).", + query: `get_first_char(C).`, + wantError: fmt.Errorf("open/4: invalid domain for open mode, should be an atom, got engine.Integer"), + wantSuccess: false, + }, + { + files: map[string][]byte{ + "file": []byte("dumb(dumber)."), + }, + program: "get_first_char(C) :- open(file, foo, Stream, _), get_char(Stream, C).", + query: `get_first_char(C).`, + wantError: fmt.Errorf("open/4: invalid open mode (read | write | append)"), + wantSuccess: false, + }, + { + files: map[string][]byte{ + "file": []byte("dumb(dumber)."), + }, + program: "get_first_char(C) :- open(file, write, Stream, _), get_char(Stream, C).", + query: `get_first_char(C).`, + wantError: fmt.Errorf("open/4: only read mode is allowed here"), + wantSuccess: false, + }, + { + files: map[string][]byte{ + "file": []byte("dumb(dumber)."), + }, + program: "get_first_char(C) :- open(file, append, Stream, _), get_char(Stream, C).", + query: `get_first_char(C).`, + wantError: fmt.Errorf("open/4: only read mode is allowed here"), + wantSuccess: false, + }, + { + files: map[string][]byte{ + "file": []byte("dumb(dumber)."), + }, + program: "get_first_char(C) :- open(file2, read, Stream, _), get_char(Stream, C).", + query: `get_first_char(C).`, + wantError: fmt.Errorf("open/4: couldn't open stream: read file2: path not found"), + wantSuccess: false, + }, + { + files: map[string][]byte{ + "file": []byte("dumb(dumber)."), + }, + program: "get_first_char(C) :- open(file, read, Stream, [option1]), get_char(Stream, C).", + query: `get_first_char(C).`, + wantError: fmt.Errorf("open/4: options is not allowed here"), + wantSuccess: false, + }, + } + for nc, tc := range cases { + Convey(fmt.Sprintf("Given the query #%d: %s", nc, tc.query), func() { + Convey("and a mocked file system", func() { + uri, _ := url.Parse("file://dump.pl") + mockedFS := testutil.NewMockFS(ctrl) + mockedFS.EXPECT().Open(gomock.Any()).AnyTimes().DoAndReturn(func(name string) (fs.VirtualFile, error) { + for key, bytes := range tc.files { + if key == name { + return fs.NewVirtualFile(bytes, uri, time.Now()), nil + } + } + return fs.VirtualFile{}, &fs2.PathError{ + Op: "read", + Path: "file2", + Err: fmt.Errorf("path not found"), + } + }) + Convey("and a context", func() { + db := tmdb.NewMemDB() + stateStore := store.NewCommitMultiStore(db) + ctx := sdk.NewContext(stateStore, tmproto.Header{}, false, log.NewNopLogger()) + + Convey("and a vm", func() { + interpreter := testutil.NewComprehensiveInterpreterMust(ctx) + interpreter.FS = mockedFS + interpreter.Register4(engine.NewAtom("open"), Open) + + err := interpreter.Compile(ctx, tc.program) + So(err, ShouldBeNil) + + Convey("When the predicate is called", func() { + sols, err := interpreter.QueryContext(ctx, tc.query) + + Convey("Then the error should be nil", func() { + So(err, ShouldBeNil) + So(sols, ShouldNotBeNil) + + Convey("and the bindings should be as expected", func() { + var got []types.TermResults + for sols.Next() { + m := types.TermResults{} + err := sols.Scan(m) + So(err, ShouldBeNil) + + got = append(got, m) + } + if tc.wantError != nil { + So(sols.Err(), ShouldNotBeNil) + So(sols.Err().Error(), ShouldEqual, tc.wantError.Error()) + } else { + So(sols.Err(), ShouldBeNil) + + if tc.wantSuccess { + So(len(got), ShouldBeGreaterThan, 0) + So(len(got), ShouldEqual, len(tc.wantResult)) + for iGot, resultGot := range got { + for varGot, termGot := range resultGot { + So(testutil.ReindexUnknownVariables(termGot), ShouldEqual, tc.wantResult[iGot][varGot]) + } + } + } else { + So(len(got), ShouldEqual, 0) + } + } + }) + }) + }) + }) + }) + }) + }) + } + }) +} diff --git a/x/logic/testutil/logic.go b/x/logic/testutil/logic.go index 65147b12..482c132f 100644 --- a/x/logic/testutil/logic.go +++ b/x/logic/testutil/logic.go @@ -49,6 +49,7 @@ func NewComprehensiveInterpreterMust(ctx context.Context) (i *prolog.Interpreter i.Register1(engine.NewAtom("current_output"), engine.CurrentOutput) i.Register1(engine.NewAtom("current_input"), engine.CurrentInput) i.Register2(engine.NewAtom("put_char"), engine.PutChar) + i.Register2(engine.NewAtom("get_char"), engine.GetChar) i.Register3(engine.NewAtom("write_term"), engine.WriteTerm) err := i.Compile(ctx, bootstrap.Bootstrap())