fix coloring of last n_batch of prompt, and refactor line input (#221)
* fix coloring of last `n_batch` of prompt, and refactor line input
* forgot the newline that needs to be sent to the model
* (per #283) try to force flush of color reset in SIGINT handler
bitRAKE authored Mar 19, 2023
1 parent 2456837 commit 5c19c70
Showing 1 changed file with 24 additions and 34 deletions.
58 changes: 24 additions & 34 deletions main.cpp
@@ -7,6 +7,7 @@
#include <cstdio>
#include <cstring>
#include <fstream>
#include <iostream>
#include <map>
#include <string>
#include <vector>
@@ -997,11 +998,6 @@ int main(int argc, char ** argv) {
break;
}
}

// reset color to default if there is no pending user input
if (!input_noecho && params.use_color && (int) embd_inp.size() == input_consumed) {
printf(ANSI_COLOR_RESET);
}
}

// display text
@@ -1011,6 +1007,10 @@ int main(int argc, char ** argv) {
}
fflush(stdout);
}
// reset color to default if there is no pending user input
if (!input_noecho && params.use_color && (int)embd_inp.size() == input_consumed) {
printf(ANSI_COLOR_RESET);
}

// in interactive mode, and not currently processing queued inputs;
// check if we should prompt the user for more
@@ -1032,43 +1032,33 @@ int main(int argc, char ** argv) {
}

// currently being interactive
if (params.use_color) printf(ANSI_BOLD ANSI_COLOR_GREEN);
std::string buffer;
std::string line;
bool another_line = true;
while (another_line) {
fflush(stdout);
char buf[256] = {0};
int n_read;
if (params.use_color) printf(ANSI_BOLD ANSI_COLOR_GREEN);
if (scanf("%255[^\n]%n%*c", buf, &n_read) <= 0) {
// presumably an empty line, consume the newline
std::ignore = scanf("%*c");
n_read=0;
}
if (params.use_color) printf(ANSI_COLOR_RESET);

if (n_read > 0 && buf[n_read-1]=='\\') {
another_line = true;
buf[n_read-1] = '\n';
buf[n_read] = 0;
} else {
do {
std::getline(std::cin, line);

tarruda Mar 19, 2023

Can we switch back to using scanf/<cstdio> here? I was using the FILE * abstraction to read from terminal or a TCP socket in #278 @bitRAKE @ggerganov

tarruda Mar 19, 2023

To clarify, I did a quick search and it appears C++ does not provide a standard abstraction over file descriptors: https://stackoverflow.com/a/5253726

Green-Sky Mar 19, 2023

Collaborator

Hm, I know this sounds like a lot of work, but what if you implement a std::basic_streambuf (https://en.cppreference.com/w/cpp/io/basic_streambuf) for the TCP stream instead?

tarruda Mar 19, 2023

I'd be fine doing it if there were a C++ equivalent of fdopen that let me pass a file descriptor and get back an object with the std::cin (std::istream) interface.

But implementing basic_streambuf is quite a bit outside of my comfort zone; I'm not familiar enough with C++ to do that confidently.

Green-Sky Mar 19, 2023

Collaborator

Oh wow, I did not see your second message before I sent mine. But yes, the Stack Overflow answer basically says what I suggested. :)

tarruda Mar 19, 2023

I asked ChatGPT to do it for me and it sent me the following solution:

#include <fcntl.h>    // open, O_RDONLY
#include <iostream>
#include <string>
#include <unistd.h>   // read, close

class PosixStream : public std::istream {
  public:
    PosixStream(int fd) : std::istream(&buf), buf(fd) {}

  private:
    class PosixStreamBuf : public std::streambuf {
      public:
        PosixStreamBuf(int fd) : fd(fd) {}

      protected:
        virtual int_type underflow() {
          if (gptr() < egptr()) {
            return traits_type::to_int_type(*gptr());
          }

          ssize_t num_read = read(fd, buffer, BUFFER_SIZE);
          if (num_read <= 0) {
            return traits_type::eof();
          }

          setg(buffer, buffer, buffer + num_read);
          return traits_type::to_int_type(*gptr());
        }

      private:
        static const int BUFFER_SIZE = 1024;
        int fd;
        char buffer[BUFFER_SIZE + 1];
    };

    PosixStreamBuf buf;
};
int main() {
  int fd = open("file.txt", O_RDONLY);
  PosixStream stream(fd);

  // Use stream as if it were std::cin
  std::string line;
  while (std::getline(stream, line)) {
    std::cout << line << std::endl;
  }

  close(fd);
  return 0;
}

Seems like it would work. Should I use it?

Green-Sky Mar 19, 2023

Collaborator

I mean, I don't know raw stl container that much to say for sure, but it looks correct. 😅

tarruda Mar 19, 2023

OK, the ChatGPT solution appears to work with a few tweaks, and I've rebased the PR with these changes. I had tried the SO solution (__gnu_cxx::stdio_filebuf), but it is not part of the standard and didn't compile on Mac.

Green-Sky Mar 19, 2023

Collaborator

yep cool

bitRAKE Mar 20, 2023

Author Contributor

This seems to work for me - what am I missing?

#include <iostream>
#include <fstream>
#include <string>

using namespace std;

int main(int argc, char** argv) {
    istream* input = &cin;  // default to stdin so the pointer is never uninitialized
    string inputType = "stdin";
    for (int i = 1; i < argc; i++) {
        if (string(argv[i]) == "-i" && i < argc - 1) {
            inputType = argv[++i];  // assign the outer variable instead of shadowing it
            if (inputType == "stdin") input = &cin;
            else if (inputType == "file") input = new ifstream("file.txt");
            else if (inputType == "tcp") {
                // a socket descriptor cannot be turned into an istream here
                cerr << "TCP input is not implemented" << endl;
                return 1;
            }
            else {
                cerr << "Unknown input type " << inputType << endl;
                return 1;
            }
        }
    }
    string line;
    while (getline(*input, line)) {
        cout << line << endl;
    }
    // Clean up input stream if necessary (only the ifstream was heap-allocated)
    if (inputType == "file") {
        delete input;
    }
}

Also, kind of from ChatGPT. ;)

tarruda Mar 20, 2023

@bitRAKE the problem was wrapping a file descriptor in an istream, but ChatGPT provided a working solution which I've added to the PR branch. Already rebased/adapted on top of your changes 😄

if (line.empty() || line.back() != '\\') {
another_line = false;
buf[n_read] = '\n';
buf[n_read+1] = 0;
}

std::vector<gpt_vocab::id> line_inp = ::llama_tokenize(vocab, buf, false);
embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());

if (params.instruct) {
embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end());
} else {
line.pop_back(); // Remove the continue character
}
buffer += line + '\n'; // Append the line to the result
} while (another_line);
if (params.use_color) printf(ANSI_COLOR_RESET);

remaining_tokens -= line_inp.size();
std::vector<gpt_vocab::id> line_inp = ::llama_tokenize(vocab, buffer, false);
embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());

input_noecho = true; // do not echo this again
if (params.instruct) {
embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end());
}

is_interacting = false;
remaining_tokens -= line_inp.size();

input_noecho = true; // do not echo this again
}
is_interacting = false;
}

// end of text token