Add support for completion
wilsonsilva committed May 30, 2024
1 parent df665a2 commit 4b5cab8
Showing 10 changed files with 356 additions and 23 deletions.
9 changes: 9 additions & 0 deletions CHANGELOG.md
@@ -5,6 +5,14 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.1.1/)
and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html).

## [0.3.0] - 2024-05-30

### Added

- Added support for completion requests, which you can use to query Mistral's latest code model,
  [Codestral](https://mistral.ai/news/codestral/).
  See [this example](https://github.com/wilsonsilva/mistral/blob/0.3.0/examples/code_completion.rb) to get started.

## [0.2.0] - 2024-05-23

### Added
@@ -26,5 +34,6 @@ Ports [mistralai/client-python#86](https://github.com/mistralai/client-python/pu
- Initial release. Feature parity with `v0.1.8` of the
[mistralai/client-python](https://github.com/mistralai/client-python)

[0.3.0]: https://github.com/wilsonsilva/mistral/compare/v0.2.0...v0.3.0
[0.2.0]: https://github.com/wilsonsilva/mistral/compare/v0.1.0...v0.2.0
[0.1.0]: https://github.com/wilsonsilva/mistral/compare/28e7c9...v0.1.0
20 changes: 11 additions & 9 deletions README.md
@@ -80,15 +80,17 @@ end

The [`examples`](https://github.com/wilsonsilva/mistral/tree/main/examples) folder shows how to:

| File Name | Description |
|--------------------------------------------------------------------------------------------------------------------|----------------------------------------------------------------|
| [`chat_no_streaming.rb`](https://github.com/wilsonsilva/mistral/blob/main/examples/chat_no_streaming.rb) | How to use the chat endpoint without streaming |
| [`chat_with_streaming.rb`](https://github.com/wilsonsilva/mistral/blob/main/examples/chat_with_streaming.rb) | How to use the chat endpoint with streaming |
| [`chatbot_with_streaming.rb`](https://github.com/wilsonsilva/mistral/blob/main/examples/chatbot_with_streaming.rb) | A simple interactive chatbot using streaming |
| [`embeddings.rb`](https://github.com/wilsonsilva/mistral/blob/main/examples/embeddings.rb) | How to use the embeddings endpoint |
| [`function_calling.rb`](https://github.com/wilsonsilva/mistral/blob/main/examples/function_calling.rb) | How to call functions using the chat endpoint |
| [`json_format.rb`](https://github.com/wilsonsilva/mistral/blob/main/examples/json_format.rb) | How to request and parse JSON responses from the chat endpoint |
| [`list_models.rb`](https://github.com/wilsonsilva/mistral/blob/main/examples/list_models.rb) | How to list available models |
| File Name | Description |
|--------------------------------------------------------------------------------------------------------------------------|----------------------------------------------------------------|
| [`chat_no_streaming.rb`](https://github.com/wilsonsilva/mistral/blob/main/examples/chat_no_streaming.rb) | How to use the chat endpoint without streaming |
| [`chat_with_streaming.rb`](https://github.com/wilsonsilva/mistral/blob/main/examples/chat_with_streaming.rb) | How to use the chat endpoint with streaming |
| [`chatbot_with_streaming.rb`](https://github.com/wilsonsilva/mistral/blob/main/examples/chatbot_with_streaming.rb) | A simple interactive chatbot using streaming |
| [`code_completion.rb`](https://github.com/wilsonsilva/mistral/blob/main/examples/code_completion.rb) | How to perform a code completion |
| [`completion_with_streaming.rb`](https://github.com/wilsonsilva/mistral/blob/main/examples/completion_with_streaming.rb) | How to perform a code completion with streaming |
| [`embeddings.rb`](https://github.com/wilsonsilva/mistral/blob/main/examples/embeddings.rb) | How to use the embeddings endpoint |
| [`function_calling.rb`](https://github.com/wilsonsilva/mistral/blob/main/examples/function_calling.rb) | How to call functions using the chat endpoint |
| [`json_format.rb`](https://github.com/wilsonsilva/mistral/blob/main/examples/json_format.rb) | How to request and parse JSON responses from the chat endpoint |
| [`list_models.rb`](https://github.com/wilsonsilva/mistral/blob/main/examples/list_models.rb) | How to list available models |

## 🔨 Development

9 changes: 5 additions & 4 deletions examples/chatbot_with_streaming.rb
@@ -10,11 +10,12 @@
require 'mistral'

MODEL_LIST = %w[
mistral-tiny
mistral-small
mistral-medium
mistral-tiny-latest
mistral-small-latest
mistral-medium-latest
codestral-latest
].freeze
DEFAULT_MODEL = 'mistral-small'
DEFAULT_MODEL = 'mistral-small-latest'
DEFAULT_TEMPERATURE = 0.7
LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
# A hash of all commands and their arguments, used for tab completion.
24 changes: 24 additions & 0 deletions examples/code_completion.rb
@@ -0,0 +1,24 @@
#!/usr/bin/env ruby
# frozen_string_literal: true

require 'bundler/setup'
require 'dotenv/load'
require 'mistral'

api_key = ENV.fetch('MISTRAL_API_KEY')
client = Mistral::Client.new(api_key: api_key)

prompt = 'def fibonacci(n: int):'
suffix = "n = int(input('Enter a number: '))\nprint(fibonacci(n))"

response = client.completion(
  model: 'codestral-latest',
  prompt: prompt,
  suffix: suffix
)

print <<~COMPLETION
#{prompt}
#{response.choices[0].message.content}
#{suffix}
COMPLETION
24 changes: 24 additions & 0 deletions examples/completion_with_streaming.rb
@@ -0,0 +1,24 @@
#!/usr/bin/env ruby
# frozen_string_literal: true

require 'bundler/setup'
require 'dotenv/load'
require 'mistral'

api_key = ENV.fetch('MISTRAL_API_KEY')
client = Mistral::Client.new(api_key: api_key)

prompt = 'def fibonacci(n: int):'
suffix = "n = int(input('Enter a number: '))\nprint(fibonacci(n))"

print(prompt)

client.completion_stream(
  model: 'codestral-latest',
  prompt: prompt,
  suffix: suffix
).each do |chunk|
  print(chunk.choices[0].delta.content) unless chunk.choices[0].delta.content.nil?
end

print(suffix)
67 changes: 67 additions & 0 deletions lib/mistral/client.rb
@@ -158,6 +158,73 @@ def list_models
raise Mistral::Error.new(message: 'No response received')
end

# A completion endpoint that returns a single response.
#
# @param model [String] the name of the model to get the completion with, e.g. codestral-latest
# @param prompt [String] the prompt to complete
# @param suffix [String, nil] the suffix to append to the prompt for fill-in-the-middle completion
# @param temperature [Float, nil] the temperature to use for sampling, e.g. 0.5.
# @param max_tokens [Integer, nil] the maximum number of tokens to generate, e.g. 100. Defaults to nil.
# @param top_p [Float, nil] the cumulative probability of tokens to generate, e.g. 0.9. Defaults to nil.
# @param random_seed [Integer, nil] the random seed to use for sampling, e.g. 42. Defaults to nil.
# @param stop [Array<String>, nil] a list of tokens to stop generation at, e.g. ["\n\n"]
# @return [ChatCompletionResponse] a response object containing the generated text.
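# @example Fill-in-the-middle completion (illustrative sketch; assumes a valid MISTRAL_API_KEY in the environment)
#   client = Mistral::Client.new(api_key: ENV.fetch('MISTRAL_API_KEY'))
#   response = client.completion(model: 'codestral-latest', prompt: 'def add(a, b):')
#   puts response.choices[0].message.content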
#
def completion(
  model:,
  prompt:,
  suffix: nil,
  temperature: nil,
  max_tokens: nil,
  top_p: nil,
  random_seed: nil,
  stop: nil
)
  request = make_completion_request(
    prompt:, model:, suffix:, temperature:, max_tokens:, top_p:, random_seed:, stop:
  )
  single_response = request('post', 'v1/fim/completions', json: request, stream: false)

  single_response.each do |response|
    return ChatCompletionResponse.new(**response)
  end

  raise Error, 'No response received'
end

# A completion endpoint that streams responses as they are generated.
#
# @param model [String] the name of the model to get completions with, e.g. codestral-latest
# @param prompt [String] the prompt to complete
# @param suffix [String, nil] the suffix to append to the prompt for fill-in-the-middle completion
# @param temperature [Float, nil] the temperature to use for sampling, e.g. 0.5.
# @param max_tokens [Integer, nil] the maximum number of tokens to generate, e.g. 100. Defaults to nil.
# @param top_p [Float, nil] the cumulative probability of tokens to generate, e.g. 0.9. Defaults to nil.
# @param random_seed [Integer, nil] the random seed to use for sampling, e.g. 42. Defaults to nil.
# @param stop [Array<String>, nil] a list of tokens to stop generation at, e.g. ["\n\n"]
# @return [Enumerator<ChatCompletionStreamResponse>] a lazy enumerator that yields response objects containing
#   the generated text.
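# @example Streaming a completion (illustrative sketch; assumes the client is configured as above)
#   client.completion_stream(model: 'codestral-latest', prompt: 'def add(a, b):').each do |chunk|
#     print(chunk.choices[0].delta.content) unless chunk.choices[0].delta.content.nil?
#   end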
#
def completion_stream(
  model:,
  prompt:,
  suffix: nil,
  temperature: nil,
  max_tokens: nil,
  top_p: nil,
  random_seed: nil,
  stop: nil
)
  request = make_completion_request(
    prompt:, model:, suffix:, temperature:, max_tokens:, top_p:, random_seed:, stop:, stream: true
  )
  response = request('post', 'v1/fim/completions', json: request, stream: true)

  response.lazy.map do |json_streamed_response|
    ChatCompletionStreamResponse.new(**json_streamed_response)
  end
end

private

def request(method, path, json: nil, stream: false, attempt: 1)
51 changes: 51 additions & 0 deletions lib/mistral/client_base.rb
@@ -68,6 +68,57 @@ def parse_messages(messages)
parsed_messages
end

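# Builds the request body for the fill-in-the-middle completion endpoint. Falls back to the
# client's default model when no model is given, and raises an error if neither is set.
#
# @return [Hash] the request payload sent to v1/fim/completions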
def make_completion_request(
  prompt:,
  model: nil,
  suffix: nil,
  temperature: nil,
  max_tokens: nil,
  top_p: nil,
  random_seed: nil,
  stop: nil,
  stream: false
)
  request_data = {
    'prompt' => prompt,
    'suffix' => suffix,
    'model' => model,
    'stream' => stream
  }

  request_data['stop'] = stop unless stop.nil?

  if model.nil?
    raise Error.new(message: 'model must be provided') if @default_model.nil?

    request_data['model'] = @default_model
  else
    request_data['model'] = model
  end

  request_data.merge!(
    build_sampling_params(
      temperature: temperature,
      max_tokens: max_tokens,
      top_p: top_p,
      random_seed: random_seed
    )
  )

  @logger.debug("Completion request: #{request_data}")

  request_data
end

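# Collects the optional sampling parameters, omitting any that are nil.
#
# @return [Hash] the sampling parameters to merge into a request payload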
def build_sampling_params(max_tokens: nil, random_seed: nil, temperature: nil, top_p: nil)
  params = {}
  params['temperature'] = temperature unless temperature.nil?
  params['max_tokens'] = max_tokens unless max_tokens.nil?
  params['top_p'] = top_p unless top_p.nil?
  params['random_seed'] = random_seed unless random_seed.nil?
  params
end

def make_chat_request(
messages:,
model: nil,
12 changes: 6 additions & 6 deletions test/test_chat.rb
@@ -13,7 +13,7 @@ def test_chat
body: {
messages: [{ role: 'user', content: 'What is the best French cheese?' }],
safe_prompt: false,
model: 'mistral-small',
model: 'mistral-small-latest',
stream: false
}.to_json,
headers: {
@@ -27,7 +27,7 @@
.to_return(status: 200, body: mock_chat_response_payload, headers: {})

result = @client.chat(
model: 'mistral-small',
model: 'mistral-small-latest',
messages: [
Mistral::ChatMessage.new(role: 'user', content: 'What is the best French cheese?')
]
@@ -44,7 +44,7 @@ def test_chat
body: {
messages: [{ role: 'user', content: 'What is the best French cheese?' }],
safe_prompt: false,
model: 'mistral-small',
model: 'mistral-small-latest',
stream: false
},
times: 1
@@ -64,7 +64,7 @@ def test_chat_streaming
{ role: 'user', content: 'What is the best French cheese?' }
],
safe_prompt: false,
model: 'mistral-small',
model: 'mistral-small-latest',
stream: true
}.to_json,
headers: {
@@ -78,7 +78,7 @@
.to_return(status: 200, body: mock_chat_response_streaming_payload.join, headers: {})

chat_stream_result = @client.chat_stream(
model: 'mistral-small',
model: 'mistral-small-latest',
messages: [
Mistral::ChatMessage.new(role: 'user', content: 'What is the best French cheese?')
]
@@ -99,7 +99,7 @@ def test_chat_streaming
{ role: 'user', content: 'What is the best French cheese?' }
],
safe_prompt: false,
model: 'mistral-small',
model: 'mistral-small-latest',
stream: true
}.to_json,
times: 1