Add heuristic algorithm for speculative #3006

Merged
merged 6 commits into from
Sep 14, 2023
15 changes: 14 additions & 1 deletion examples/speculative/speculative.cpp
@@ -84,7 +84,7 @@ int main(int argc, char ** argv) {
//GGML_ASSERT(n_vocab == llama_n_vocab(ctx_dft));

// how many tokens to draft each time
- const int n_draft = params.n_draft;
+ int n_draft = params.n_draft;

int n_predict = 0;
int n_drafted = 0;
@@ -116,6 +116,8 @@ int main(int argc, char ** argv) {

// sample from the drafted tokens if any
int i_dft = 0;
bool all_accepted = false;

while (true) {
const llama_token id = llama_sample_token(ctx_tgt, NULL, NULL, params, last_tokens, candidates, i_dft);

@@ -141,6 +143,9 @@ int main(int argc, char ** argv) {
++n_past_dft;
++i_dft;

if (i_dft == (int) drafted.size()) {
all_accepted = true;
}
continue;
}

@@ -154,6 +159,14 @@ int main(int argc, char ** argv) {
break;
}

if (drafted.size() > 0 && all_accepted) {
n_draft += 2;
LOG("all drafted tokens accepted, n_draft = %d\n", n_draft);
} else {
n_draft -= 1;
LOG("drafted token rejected, n_draft = %d\n", n_draft);
}

Contributor @bobqianic, Sep 4, 2023

Why does n_draft go up by 2 when all drafted tokens are accepted, but go down by 1 when a drafted token is rejected? Is there a more efficient algorithm to handle this? The current approach seems similar to a simplified version of the TCP Friendly Rate Control algorithm.

Contributor Author @leng-yue, Sep 4, 2023

That's pretty much borrowed from Hugging Face's code. We could fine-tune it by tweaking some parameters. Since getting all tokens right is challenging, it seems reasonable to bump up n_draft by 2 when everything aligns and decrease it by 1 otherwise.
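
As a rough illustration of the rule discussed above, here is a minimal standalone sketch. The helper names `next_n_draft` and `replay` are hypothetical and do not exist in speculative.cpp; the logic only mirrors the `+2` / `-1` branches from the diff and applies no lower bound.

```cpp
#include <vector>

// Hypothetical helper (not part of speculative.cpp): returns the next draft
// length given the current one and the outcome of the last round.
// Mirrors the +2 / -1 rule from the diff; no lower bound is applied here.
int next_n_draft(int n_draft, bool had_drafts, bool all_accepted) {
    if (had_drafts && all_accepted) {
        return n_draft + 2;  // every drafted token was accepted: draft more
    }
    return n_draft - 1;      // a rejection (or nothing drafted): draft less
}

// Replays a sequence of per-round outcomes (true = all drafted tokens
// accepted), assuming tokens were drafted in every round.
int replay(int n_draft, const std::vector<bool> & outcomes) {
    for (bool ok : outcomes) {
        n_draft = next_n_draft(n_draft, true, ok);
    }
    return n_draft;
}
```

Under this rule the draft length stays level when one round in three is fully accepted, since `+2 - 1 - 1 = 0`; it grows when full acceptance is more frequent than that and shrinks otherwise.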

Owner

n_draft should be constrained to not go below 2 for example and
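
A lower bound of this kind could be sketched as follows. The function name `next_n_draft_clamped` and the constant `k_min_n_draft` are assumptions made for illustration; only the minimum of 2 comes from the comment above.

```cpp
#include <algorithm>

// Assumed lower bound, taken from the "not go below 2" suggestion.
constexpr int k_min_n_draft = 2;

// Hypothetical variant of the +2 / -1 update that never returns a draft
// length below k_min_n_draft.
int next_n_draft_clamped(int n_draft, bool had_drafts, bool all_accepted) {
    const int next = (had_drafts && all_accepted) ? n_draft + 2 : n_draft - 1;
    return std::max(next, k_min_n_draft);
}
```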

Contributor Author

Here is what I got after restricting the minimum n_draft to 2:

Outputs
 // Dijkstra algorithm in C++ (4 spaces indentation + detailed comments) + sample usage:

// 1. Add all nodes to the graph, and add edges between them with distances/weights.
// 2. Call dijkstra(start_node) to get shortest paths from start_node to all other nodes.
//    It returns a map of <node, distance>.
// 3. Use path_exists(node) to check if there is a path between start_node and node.
// 4. Use get_shortest_distance(node) to get the shortest distance from start_node to node.
//    If no path exists, it returns -1.
// 5. Use reconstruct_path(node) to reconstruct the shortest path between start_node and node.
//    It returns a vector of nodes that forms the path (including both start_node and node).
//    If no path exists, it returns an empty vector.

#include <iostream>
#include <vector>
#include <map>
#include <set>
#include <queue>
using namespace std;

class Graph {
public:
        struct Edge {
                int node;
                int distance;
        };

private:
        // adjacency list of the graph (to store all edges)
        vector<vector<Edge>> adj_list;

        // map to keep track of visited nodes during dijkstra()
        map<int, bool> visited;

        // map to keep track of distances from start node to other nodes during dijkstra()
        map<int, int> distance;

public:
        Graph(int n) {
                adj_list.resize(n);
        }

        void add_edge(int src, int dest, int dist) {
                Edge edge = {dest, dist};
                adj_list[src].push_back(edge);
        }

        map<int, int> dijkstra(int start) {
                // initialize all distances as -1 (inf). 0 is the starting node.
                for (int i = 0; i < adj_list.size(); ++i) {
                        distance[i] = -1;
                }
                distance[start] = 0;

                // create a priority queue to get the minimum distance node,
                // visited helps avoid processing same node again
                priority_queue<pair<int, int>, vector<pair<int, int>>, greater<pair<int, int>>> pq;
                pq.push({0, start});

                while (!pq.empty()) {
                        // get the node with minimum distance from priority queue
                        auto top = pq.top();
                        int u = top.second;
                        pq.pop();

                        if (visited[u]) continue; // skip already visited nodes

                        // mark the node as visited
                        visited[u] = true;

                        // process all neighbours of the current node 'u'
                        for (auto edge : adj_list[u]) {
                                int v = edge.node;
                                int dist_v = distance[u] + edge.distance; // distance from 'u' to 'v'

                                // check if there is shorted path to 'v' through 'u'.
                                if (distance[v] == -1 || distance[v] > dist_v) {
                                        // update the distance of 'v' only if it is not in visited,
                                        // or can be reached with shorter distance from 'u'.
                                        distance[v] = dist_v;
                                        pq.push({dist_v, v}); // add 'v' to priority queue
                                }
                        }
                }

                return distance;
        }

        bool path_exists(int dest) {
                // return true if the destination node is in visited map.
                return visited[dest];
        }

        int get_shortest_distance(int dest) {
                // return the shortest distance from start to destination node.
                return distance[dest];
        }

        vector<int> reconstruct_path(int dest) {
                // vector to store the path
                vector<int> path;

                if (!visited[dest]) return path; // no path exists from start node to destination node.

                for (int v = dest; v != 0; v = distance[v]) {
                        path.push_back(v);
                }
                reverse(path.begin(), path.end());

                return path;
        }
};

int main() {
        // create a graph with 9 nodes (0 to 8)
        Graph g(9);

        g.add_edge(0, 1, 4);
        g.add_edge(0, 7, 8);
        g.add_edge(1, 2, 8);
        g.add_edge(1, 7, 11);
        g.add_edge(2, 3, 7);
        g.add_edge(2, 8, 2);
        g.add_edge(2, 5, 4);
        g.add_edge(3, 4, 9);
        g.add_edge(3, 5, 14);
        g.add_edge(4, 5, 10);
        g.add_edge(5, 6, 2);
        g.add_edge(6, 7, 1);
        g.add_edge(6, 8, 6);
        g.add_edge(7, 8, 7);

        // call dijkstra() with start node as 0
        auto distance = g.dijkstra(0);

        for (int i = 1; i < distance.size(); ++i) {
                cout << "Distance from 0 to " << i << ": ";
                if (!g.path_exists(i)) cout << "No path exists.";
                else cout << g.get_shortest_distance(i);
                cout << endl;
        }

        // print the shortest path from 0 to 8
        auto path = g.reconstruct_path(8);
        for (int node : path) {
                cout << node << " ";
        }
        cout << endl;

        return 0;
}

encoded   24 tokens in    0.418 seconds, speed:   57.356 t/s
decoded 1546 tokens in  109.633 seconds, speed:   14.102 t/s

n_draft   = 75
n_predict = 1546
n_drafted = 2261
n_accept  = 1263
accept    = 55.860%

draft:

llama_print_timings:        load time =   725.46 ms
llama_print_timings:      sample time =  3855.13 ms /     1 runs   ( 3855.13 ms per token,     0.26 tokens per second)
llama_print_timings: prompt eval time =   101.84 ms /    24 tokens (    4.24 ms per token,   235.67 tokens per second)
llama_print_timings:        eval time = 47541.13 ms /  2529 runs   (   18.80 ms per token,    53.20 tokens per second)
llama_print_timings:       total time = 110050.98 ms

target:

llama_print_timings:        load time =  2122.55 ms
llama_print_timings:      sample time =   501.15 ms /  1546 runs   (    0.32 ms per token,  3084.92 tokens per second)
llama_print_timings: prompt eval time = 54614.28 ms /  2495 tokens (   21.89 ms per token,    45.68 tokens per second)
llama_print_timings:        eval time =  2831.66 ms /    72 runs   (   39.33 ms per token,    25.43 tokens per second)
llama_print_timings:       total time = 110779.95 ms

Full log: speculative.139912996806656.log
It shows that n_draft never goes under 10 in this case.


For comparison, this run doesn't use the heuristic algorithm:

Outputs
 // Dijkstra algorithm in C++ (4 spaces indentation + detailed comments) + sample usage:

// 1. Add all nodes to the graph, and add edges between them with distances/weights.
// 2. Call dijkstra(start_node) to get shortest paths from start_node to all other nodes.
//    It returns a map of <node, distance>.
// 3. Use path_exists(node) to check if there is a path between start_node and node.
// 4. Use get_shortest_distance(node) to get the shortest distance from start_node to node.
//    If no path exists, it returns -1.
// 5. Use reconstruct_path(node) to get the shortest path from start_node to node.
//    It returns a vector of nodes that make up the path.

#include <iostream>
#include <vector>
#include <map>
#include <set>
#include <queue>
using namespace std;

class Graph {
public:
        struct Edge {
                int node, distance;
        };

        // Adds a directed edge between "from" and "to" with the given "distance".
        void add_edge(int from, int to, int distance) {
                edges[from].push_back({to, distance});
        }

        // Returns true if there is an edge between "from" and "to", false otherwise.
        bool has_edge(int from, int to) const {
                for (const Edge& e : edges.at(from)) {
                        if (e.node == to) return true;
                }
                return false;
        }

        // Returns the distance between "from" and "to". If there is no edge, returns -1.
        int get_distance(int from, int to) const {
                for (const Edge& e : edges.at(from)) {
                        if (e.node == to) return e.distance;
                }
                return -1;
        }

        // Returns a map of <node, distance> representing the shortest paths from "start_node" to all other nodes.
        map<int, int> dijkstra(int start_node) const {
                priority_queue<pair<int, int>, vector<pair<int, int>>, greater<pair<int, int>>> pq; // (distance, node)
                map<int, bool> visited;
                map<int, int> distances; // <node, distance>

                pq.push({0, start_node});
                distances[start_node] = 0;

                while (!pq.empty()) {
                        auto top = pq.top();
                        int node = top.second;
                        int distance = top.first;
                        pq.pop();

                        if (visited[node]) continue; // already visited this node
                        visited[node] = true;

                        // update distances of neighbors
                        for (const Edge& edge : edges.at(node)) {
                                int neighbor_node = edge.node;
                                int new_distance = distance + edge.distance;
                                if (!distances.count(neighbor_node) || distances[neighbor_node] > new_distance) {
                                        pq.push({new_distance, neighbor_node});
                                        distances[neighbor_node] = new_distance;
                                }
                        }
                }

                return distances;
        }

        // Returns true if there is a path between "start_node" and "node", false otherwise.
        bool path_exists(int start_node, int node) const {
                map<int, bool> visited;
                queue<int> q;
                q.push(start_node);

                while (!q.empty()) {
                        int current = q.front();
                        q.pop();

                        if (current == node) return true;
                        visited[current] = true;

                        for (const Edge& edge : edges.at(current)) {
                                int neighbor_node = edge.node;
                                if (!visited[neighbor_node]) q.push(neighbor_node);
                        }
                }

                return false;
        }

        // Returns the shortest distance from "start_node" to "node". If there is no path, returns -1.
        int get_shortest_distance(int start_node, int node) const {
                map<int, bool> visited;
                queue<pair<int, int>> q; // (distance, node)
                q.push({0, start_node});

                while (!q.empty()) {
                        auto top = q.front();
                        int distance = top.first;
                        int current = top.second;
                        q.pop();

                        if (current == node) return distance;
                        visited[current] = true;

                        for (const Edge& edge : edges.at(current)) {
                                int neighbor_node = edge.node;
                                int new_distance = distance + edge.distance;
                                if (!visited[neighbor_node]) q.push({new_distance, neighbor_node});
                        }
                }

                return -1;
        }

        // Returns the shortest path from "start_node" to "node". If there is no path, returns an empty vector.
        vector<int> reconstruct_path(int start_node, int node) const {
                map<int, bool> visited;
                queue<pair<int, int>> q; // (distance, node)
                q.push({0, start_node});

                // parents[i] is the parent of i in the shortest path from start_node to i.
                map<int, int> parents;

                while (!q.empty()) {
                        auto top = q.front();
                        int distance = top.first;
                        int current = top.second;
                        q.pop();

                        if (current == node) break;
                        visited[current] = true;

                        for (const Edge& edge : edges.at(current)) {
                                int neighbor_node = edge.node;
                                int new_distance = distance + edge.distance;
                                if (!visited[neighbor_node]) {
                                        q.push({new_distance, neighbor_node});
                                        parents[neighbor_node] = current;
                                }
                        }
                }

                vector<int> path;
                for (int n = node; n != start_node; n = parents.at(n)) {
                        path.push_back(n);
                }
                path.push_back(start_node);
                reverse(path.begin(), path.end());
                return path;
        }

private:
        // Map of <node, list of edges> representing the graph.
        map<int, vector<Edge>> edges;
};

int main() {
        Graph g;
        g.add_edge(0, 1, 2);
        g.add_edge(0, 3, 4);
        g.add_edge(1, 2, 5);
        g.add_edge(1, 3, 6);
        g.add_edge(2, 3, 7);
        g.add_edge(2, 4, 8);
        g.add_edge(3, 4, 9);

        map<int, int> distances = g.dijkstra(0);
        for (auto [node, distance] : distances) {
                cout << "Distance from 0 to " << node << ": " << distance << endl;
        }

        cout << boolalpha;
        cout << "Path exists between 0 and 4: " << g.path_exists(0, 4) << endl;
        cout << "Shortest distance from 0 to 4: " << g.get_shortest_distance(0, 4) << endl;
        vector<int> path = g.reconstruct_path(0, 4);
        for (int node : path) {
                cout << node << " ";
        }
        cout << endl;
}


encoded   24 tokens in    0.419 seconds, speed:   57.252 t/s
decoded 2071 tokens in  126.799 seconds, speed:   16.333 t/s

n_draft   = 16
n_predict = 2071
n_drafted = 2400
n_accept  = 1749
accept    = 72.875%

draft:

llama_print_timings:        load time =   723.91 ms
llama_print_timings:      sample time =  4027.20 ms /     1 runs   ( 4027.20 ms per token,     0.25 tokens per second)
llama_print_timings: prompt eval time =   101.86 ms /    24 tokens (    4.24 ms per token,   235.61 tokens per second)
llama_print_timings:        eval time = 51436.25 ms /  2638 runs   (   19.50 ms per token,    51.29 tokens per second)
llama_print_timings:       total time = 127217.59 ms

target:

llama_print_timings:        load time =  2103.87 ms
llama_print_timings:      sample time =   687.69 ms /  2071 runs   (    0.33 ms per token,  3011.53 tokens per second)
llama_print_timings: prompt eval time = 67933.80 ms /  2687 tokens (   25.28 ms per token,    39.55 tokens per second)
llama_print_timings:        eval time =  2321.62 ms /    58 runs   (   40.03 ms per token,    24.98 tokens per second)
llama_print_timings:       total time = 127945.07 ms

Full log: speculative.140657253044224.log
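
The summary figures in the two logs can be cross-checked with a small sketch. The `RunStats` struct and its helpers are hypothetical; the field names simply follow the counters printed in the logs. The two runs generated different text and different token counts, so this is only a rough comparison.

```cpp
// Hypothetical summary of one run; field names follow the log output.
struct RunStats {
    int    n_predict;  // tokens decoded
    int    n_drafted;  // tokens proposed by the draft model
    int    n_accept;   // drafted tokens accepted by the target model
    double decode_s;   // decode wall time in seconds
};

// Acceptance rate in percent: accepted drafted tokens over all drafted tokens.
double accept_pct(const RunStats & s) {
    return 100.0 * s.n_accept / s.n_drafted;
}

// Decode speed in tokens per second.
double decode_tps(const RunStats & s) {
    return s.n_predict / s.decode_s;
}

// Values copied from the two logs above.
const RunStats with_heuristic    = {1546, 2261, 1263, 109.633};
const RunStats without_heuristic = {2071, 2400, 1749, 126.799};
```

With these inputs the helpers reproduce the logged values: about 55.86% acceptance at 14.10 t/s with the heuristic, against 72.875% at 16.33 t/s without it.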

Contributor Author

The models I am using are: codellama-7b.Q4_K_M.gguf and codellama-34b.Q4_K_M.gguf.

Owner @ggerganov, Sep 4, 2023

Huh, very strange. After ~140 tokens of identical results, the target model samples a different token with 100% probability:

image

What GPU backend do you use? Is this CUDA?

Edit: the 100% probability is actually expected since we are doing --top_k 1 sampling
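
To see why a single surviving candidate always ends up with probability 1, here is a minimal sketch of top-k filtering followed by a softmax. `top_k_softmax` is a hypothetical stand-in written for this note, not llama.cpp's sampler.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <functional>
#include <vector>

// Hypothetical stand-in for top-k filtering followed by softmax.
// Keeps the k largest logits and normalizes them into probabilities.
std::vector<float> top_k_softmax(std::vector<float> logits, std::size_t k) {
    k = std::min(k, logits.size());
    // Move the k largest logits to the front, in descending order.
    std::partial_sort(logits.begin(),
                      logits.begin() + static_cast<std::ptrdiff_t>(k),
                      logits.end(), std::greater<float>());
    logits.resize(k);
    if (logits.empty()) {
        return logits;
    }
    // Softmax over the survivors; subtract the max for numerical stability.
    const float max_logit = logits[0];
    float sum = 0.0f;
    for (float & l : logits) {
        l = std::exp(l - max_logit);
        sum += l;
    }
    for (float & l : logits) {
        l /= sum;
    }
    return logits;
}
```

With `k = 1` only one logit survives, so the softmax normalizes over a single term and returns exactly 1.0 regardless of how close the runner-up was.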

Owner

Also, can you try using --temp -1 instead of --top_k 1 and see if the problem persists?

Contributor Author

I am using the cuBLAS backend, and I got the same output after changing --top_k 1 to --temp -1.
With heuristic: speculative-a.140629255983104.log
Without heuristic: speculative-b.140469514129408.log


if (n_predict > params.n_predict || has_eos) {
break;
}