
Commit

feat!: token env is called PHARIA_AI_TOKEN
moldhouse committed Dec 10, 2024
1 parent d7ad3b8 commit 827b44c
Showing 5 changed files with 31 additions and 31 deletions.
.env.example (2 changes: 1 addition & 1 deletion)
@@ -1,2 +1,2 @@
-AA_API_TOKEN=
+PHARIA_AI_TOKEN=
AA_BASE_URL=
.github/workflows/test.yml (2 changes: 1 addition & 1 deletion)
@@ -27,6 +27,6 @@ jobs:
run: cargo build
- name: Run tests
env:
-AA_API_TOKEN: ${{ secrets.AA_API_TOKEN }}
+PHARIA_AI_TOKEN: ${{ secrets.PHARIA_AI_TOKEN }}
AA_BASE_URL: https://inference-api.product.pharia.com
run: cargo test
.gitignore (2 changes: 1 addition & 1 deletion)
@@ -1,4 +1,4 @@
-# Avoid commiting AA_API_TOKEN
+# Avoid commiting PHARIA_AI_TOKEN
.env

/target
src/lib.rs (2 changes: 1 addition & 1 deletion)
@@ -95,7 +95,7 @@ impl Client {

pub fn from_env() -> Result<Self, Error> {
let _ = dotenv();
let api_token = env::var("AA_API_TOKEN").unwrap();
let api_token = env::var("PHARIA_AI_TOKEN").unwrap();
let base_url = env::var("AA_BASE_URL").unwrap();
Self::with_base_url(base_url, api_token)
}
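
For downstream users, the practical effect of this breaking change is that Client::from_env now reads PHARIA_AI_TOKEN instead of AA_API_TOKEN, while AA_BASE_URL keeps its old name. A minimal migration sketch, assuming the crate is imported as aleph_alpha_client (the import path is an assumption, not part of this diff):

use aleph_alpha_client::Client; // assumed crate path, not shown in this commit

fn main() {
    // from_env expects PHARIA_AI_TOKEN (formerly AA_API_TOKEN) and AA_BASE_URL,
    // supplied via the shell environment or a .env file as in .env.example above.
    let client = Client::from_env().expect("PHARIA_AI_TOKEN and AA_BASE_URL must be set");
    let _ = client; // use the client for chat, completion, embeddings, etc.
}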
tests/integration.rs (54 changes: 27 additions & 27 deletions)
@@ -10,12 +10,12 @@ use dotenv::dotenv;
use futures_util::StreamExt;
use image::ImageFormat;

-fn api_token() -> &'static str {
-static AA_API_TOKEN: OnceLock<String> = OnceLock::new();
-AA_API_TOKEN.get_or_init(|| {
+fn pharia_ai_token() -> &'static str {
+static PHARIA_AI_TOKEN: OnceLock<String> = OnceLock::new();
+PHARIA_AI_TOKEN.get_or_init(|| {
drop(dotenv());
std::env::var("AA_API_TOKEN")
.expect("AA_API_TOKEN environment variable must be specified to run tests.")
std::env::var("PHARIA_AI_TOKEN")
.expect("PHARIA_AI_TOKEN environment variable must be specified to run tests.")
})
}

@@ -35,7 +35,7 @@ async fn chat_with_pharia_1_7b_base() {
let task = TaskChat::with_message(message);

let model = "pharia-1-llm-7b-control";
-let client = Client::with_base_url(base_url(), api_token()).unwrap();
+let client = Client::with_base_url(base_url(), pharia_ai_token()).unwrap();
let response = client.chat(&task, model, &How::default()).await.unwrap();

// Then
@@ -48,7 +48,7 @@ async fn completion_with_luminous_base() {
let task = TaskCompletion::from_text("Hello").with_maximum_tokens(1);

let model = "luminous-base";
-let client = Client::with_base_url(base_url(), api_token()).unwrap();
+let client = Client::with_base_url(base_url(), pharia_ai_token()).unwrap();
let response = client
.output_of(&task.with_model(model), &How::default())
.await
@@ -62,16 +62,16 @@ async fn completion_with_luminous_base() {

#[tokio::test]
async fn request_authentication_has_priority() {
-let bad_aa_api_token = "DUMMY";
+let bad_pharia_ai_token = "DUMMY";
let task = TaskCompletion::from_text("Hello").with_maximum_tokens(1);

let model = "luminous-base";
-let client = Client::with_base_url(base_url(), bad_aa_api_token).unwrap();
+let client = Client::with_base_url(base_url(), bad_pharia_ai_token).unwrap();
let response = client
.output_of(
&task.with_model(model),
&How {
-api_token: Some(api_token().to_owned()),
+api_token: Some(pharia_ai_token().to_owned()),
..Default::default()
},
)
@@ -96,7 +96,7 @@ async fn authentication_only_per_request() {
.output_of(
&task.with_model(model),
&How {
-api_token: Some(api_token().to_owned()),
+api_token: Some(pharia_ai_token().to_owned()),
..Default::default()
},
)
@@ -139,7 +139,7 @@ async fn semanitc_search_with_luminous_base() {
temperature, traditionally in a wood-fired oven.",
);
let query = Prompt::from_text("What is Pizza?");
-let client = Client::with_base_url(base_url(), api_token()).unwrap();
+let client = Client::with_base_url(base_url(), pharia_ai_token()).unwrap();

// When
let robot_embedding_task = TaskSemanticEmbedding {
@@ -202,7 +202,7 @@ async fn complete_structured_prompt() {
sampling: Sampling::MOST_LIKELY,
};
let model = "luminous-base";
-let client = Client::with_base_url(base_url(), api_token()).unwrap();
+let client = Client::with_base_url(base_url(), pharia_ai_token()).unwrap();
let response = client
.output_of(&task.with_model(model), &How::default())
.await
@@ -231,7 +231,7 @@ async fn maximum_tokens_none_request() {
sampling: Sampling::MOST_LIKELY,
};
let model = "luminous-base";
-let client = Client::with_base_url(base_url(), api_token()).unwrap();
+let client = Client::with_base_url(base_url(), pharia_ai_token()).unwrap();
let response = client
.output_of(&task.with_model(model), &How::default())
.await
@@ -252,7 +252,7 @@ async fn explain_request() {
target: " How is it going?",
granularity: Granularity::default().with_prompt_granularity(PromptGranularity::Sentence),
};
-let client = Client::with_base_url(base_url(), api_token()).unwrap();
+let client = Client::with_base_url(base_url(), pharia_ai_token()).unwrap();

// When
let response = client
@@ -282,7 +282,7 @@ async fn explain_request_with_auto_granularity() {
target: " How is it going?",
granularity: Granularity::default(),
};
-let client = Client::with_base_url(base_url(), api_token()).unwrap();
+let client = Client::with_base_url(base_url(), pharia_ai_token()).unwrap();

// When
let response = client
@@ -314,7 +314,7 @@ async fn explain_request_with_image_modality() {
target: " a cat.",
granularity: Granularity::default().with_prompt_granularity(PromptGranularity::Paragraph),
};
-let client = Client::with_base_url(base_url(), api_token()).unwrap();
+let client = Client::with_base_url(base_url(), pharia_ai_token()).unwrap();

// When
let response = client
@@ -364,7 +364,7 @@ async fn describe_image_starting_from_a_path() {
sampling: Sampling::MOST_LIKELY,
};
let model = "luminous-base";
-let client = Client::with_base_url(base_url(), api_token()).unwrap();
+let client = Client::with_base_url(base_url(), pharia_ai_token()).unwrap();
let response = client
.output_of(&task.with_model(model), &How::default())
.await
@@ -393,7 +393,7 @@ async fn describe_image_starting_from_a_dyn_image() {
sampling: Sampling::MOST_LIKELY,
};
let model = "luminous-base";
-let client = Client::with_base_url(base_url(), api_token()).unwrap();
+let client = Client::with_base_url(base_url(), pharia_ai_token()).unwrap();
let response = client
.output_of(&task.with_model(model), &How::default())
.await
@@ -419,7 +419,7 @@ async fn only_answer_with_specific_animal() {
},
};
let model = "luminous-base";
-let client = Client::with_base_url(base_url(), api_token()).unwrap();
+let client = Client::with_base_url(base_url(), pharia_ai_token()).unwrap();
let response = client
.output_of(&task.with_model(model), &How::default())
.await
@@ -446,7 +446,7 @@ async fn answer_should_continue() {
},
};
let model = "luminous-base";
-let client = Client::with_base_url(base_url(), api_token()).unwrap();
+let client = Client::with_base_url(base_url(), pharia_ai_token()).unwrap();
let response = client
.output_of(&task.with_model(model), &How::default())
.await
@@ -473,7 +473,7 @@ async fn batch_semanitc_embed_with_luminous_base() {
temperature, traditionally in a wood-fired oven.",
);

-let client = Client::with_base_url(base_url(), api_token()).unwrap();
+let client = Client::with_base_url(base_url(), pharia_ai_token()).unwrap();

// When
let embedding_task = TaskBatchSemanticEmbedding {
@@ -498,7 +498,7 @@ async fn tokenization_with_luminous_base() {
// Given
let input = "Hello, World!";

-let client = Client::with_base_url(base_url(), api_token()).unwrap();
+let client = Client::with_base_url(base_url(), pharia_ai_token()).unwrap();

// When
let task1 = TaskTokenization::new(input, false, true);
@@ -535,7 +535,7 @@ async fn detokenization_with_luminous_base() {
// Given
let input = vec![49222, 15, 5390, 4];

-let client = Client::with_base_url(base_url(), api_token()).unwrap();
+let client = Client::with_base_url(base_url(), pharia_ai_token()).unwrap();

// When
let task = TaskDetokenization { token_ids: &input };
@@ -552,7 +552,7 @@ async fn detokenization_with_luminous_base() {
#[tokio::test]
async fn fetch_tokenizer_for_pharia_1_llm_7b() {
// Given
-let client = Client::with_base_url(base_url(), api_token()).unwrap();
+let client = Client::with_base_url(base_url(), pharia_ai_token()).unwrap();

// When
let tokenizer = client
@@ -567,7 +567,7 @@ async fn fetch_tokenizer_for_pharia_1_llm_7b() {
#[tokio::test]
async fn stream_completion() {
// Given a streaming completion task
-let client = Client::with_base_url(base_url(), api_token()).unwrap();
+let client = Client::with_base_url(base_url(), pharia_ai_token()).unwrap();
let task = TaskCompletion::from_text("").with_maximum_tokens(7);

// When the events are streamed and collected
@@ -600,7 +600,7 @@ async fn stream_completion() {
#[tokio::test]
async fn stream_chat_with_pharia_1_llm_7b() {
// Given a streaming completion task
-let client = Client::with_base_url(base_url(), api_token()).unwrap();
+let client = Client::with_base_url(base_url(), pharia_ai_token()).unwrap();
let message = Message::user("Hello,");
let task = TaskChat::with_messages(vec![message]).with_maximum_tokens(7);

