Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Minor fixes: strip json response from llm #189

Merged
merged 5 commits into from
Oct 12, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions vibi-dpu/src/graph/file_imports.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use serde_json::json;

use crate::{graph::utils::numbered_content, utils::review::Review};

use super::utils::{all_code_files, call_llm_api, read_file};
use super::utils::{all_code_files, call_llm_api, read_file, strip_json_prefix};

// #[derive(Debug, Serialize, Default, Deserialize, Clone)]
// struct LlmImportLineInput {
Expand Down Expand Up @@ -346,7 +346,10 @@ impl ImportIdentifier {
log::debug!("[ImportIdentifier/get_import_path] Unable to call llm api");
return None;
}
let import_path_str = import_path_opt.expect("Empty import_path_opt");
let mut import_path_str = import_path_opt.expect("Empty import_path_opt");
if let Some(stripped_json) = strip_json_prefix(&import_path_str) {
import_path_str = stripped_json.to_string();
}
let import_path_res = serde_json::from_str(&import_path_str);
if import_path_res.is_err() {
log::debug!(
Expand All @@ -355,6 +358,9 @@ impl ImportIdentifier {
return None;
}
let import_path: ImportPathOutput = import_path_res.expect("Unacaught error in import_path_res");
if !import_path.get_matching_import().possible_file_path().is_empty() {
return None;
}
return Some(import_path);
}

Expand Down
17 changes: 14 additions & 3 deletions vibi-dpu/src/graph/function_call.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use serde::{Deserialize, Serialize};
use std::io::BufRead;
use crate::utils::review::Review;

use super::{gitops::{HunkDiffLines, HunkDiffMap}, utils::{call_llm_api, detect_language, numbered_content, read_file}};
use super::{gitops::{HunkDiffLines, HunkDiffMap}, utils::{call_llm_api, detect_language, numbered_content, read_file, strip_json_prefix}};

#[derive(Debug, Serialize, Default, Deserialize, Clone)]
pub struct FunctionCallChunk {
Expand Down Expand Up @@ -177,6 +177,10 @@ impl FunctionCallsOutput {
/// Borrows the list of function calls stored in this output.
pub fn function_calls(&self) -> &Vec<FunctionCall> {
    &self.function_calls
}

/// Removes, in place, every stored call whose `function_name()` is empty.
///
/// The remaining entries keep their relative order.
pub fn trim_empty_function_calls(&mut self) {
    self.function_calls
        .retain(|call| !call.function_name().is_empty());
}
}

// Instruction structure
Expand Down Expand Up @@ -291,14 +295,21 @@ impl FunctionCallIdentifier {
log::error!("[FunctionCallIdentifier/functions_in_chunk] Unable to call llm for chunk: {:?}", chunk);
return None;
}
let prompt_response = prompt_response_opt.expect("Empty prompt_response_opt");
let mut prompt_response = prompt_response_opt.expect("Empty prompt_response_opt");
if let Some(stripped_json) = strip_json_prefix(&prompt_response) {
prompt_response = stripped_json.to_string();
}
let deserialized_response = serde_json::from_str(&prompt_response);
if deserialized_response.is_err() {
let e = deserialized_response.expect_err("Empty error in deserialized_response");
log::error!("[FunctionCallIdentifier/functions_in_chunk] Error in deserializing response: {:?}", e);
return None;
}
let func_calls: FunctionCallsOutput = deserialized_response.expect("Empty error in deserialized_response");
let mut func_calls: FunctionCallsOutput = deserialized_response.expect("Empty error in deserialized_response");
if func_calls.function_calls().is_empty() {
return None;
}
func_calls.trim_empty_function_calls();
tapishr marked this conversation as resolved.
Show resolved Hide resolved
return Some(func_calls);
}

Expand Down
7 changes: 5 additions & 2 deletions vibi-dpu/src/graph/function_line_range.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use serde::{Deserialize, Serialize};

use crate::graph::utils::numbered_content;

use super::{function_call::FunctionCall, gitops::HunkDiffLines, utils::{call_llm_api, read_file}};
use super::{function_call::FunctionCall, gitops::HunkDiffLines, utils::{call_llm_api, read_file, strip_json_prefix}};

#[derive(Debug, Serialize, Default, Deserialize, Clone)]
pub struct FuncDefInfo {
Expand Down Expand Up @@ -181,7 +181,10 @@ impl FunctionDefIdentifier {
log::error!("[FunctionCallIdentifier/functions_in_chunk] Unable to call llm for chunk: {:?}", chunk);
return None;
}
let prompt_response = prompt_response_opt.expect("Empty prompt_response_opt");
let mut prompt_response = prompt_response_opt.expect("Empty prompt_response_opt");
if let Some(stripped_json) = strip_json_prefix(&prompt_response) {
prompt_response = stripped_json.to_string();
}
let deserialized_response = serde_json::from_str(&prompt_response);
if deserialized_response.is_err() {
let e = deserialized_response.expect_err("Empty error in deserialized_response");
Expand Down
16 changes: 11 additions & 5 deletions vibi-dpu/src/graph/function_name.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::collections::HashMap;

use serde::{Deserialize, Serialize};
use super::utils::{call_llm_api, read_file};
use super::utils::{call_llm_api, read_file, strip_json_prefix};

// Struct to represent the output schema
#[derive(Serialize, Deserialize, Debug)]
Expand Down Expand Up @@ -89,15 +89,21 @@ impl FunctionNameIdentifier {
log::error!("[FunctionNameIdentifier/function_name_in_line] Unable to call llm for code line: {:?}", code_line);
return None;
}
let prompt_response = prompt_response_opt.expect("Empty prompt_response_opt");
let mut prompt_response = prompt_response_opt.expect("Empty prompt_response_opt");
if let Some(stripped_json) = strip_json_prefix(&prompt_response) {
prompt_response = stripped_json.to_string();
}
tapishr marked this conversation as resolved.
Show resolved Hide resolved
let deserialized_response = serde_json::from_str(&prompt_response);
if deserialized_response.is_err() {
let e = deserialized_response.expect_err("Empty error in deserialized_response");
log::error!("[FunctionNameIdentifier/function_name_in_line] Error in deserializing response: {:?}", e);
return None;
}
let func_calls: FunctionNameOutput = deserialized_response.expect("Empty error in deserialized_response");
self.cached_output.insert(code_line.to_string(), func_calls.get_function_name().to_string());
return Some(func_calls);
let func_name: FunctionNameOutput = deserialized_response.expect("Empty error in deserialized_response");
if func_name.get_function_name().is_empty() {
return None;
}
self.cached_output.insert(code_line.to_string(), func_name.get_function_name().to_string());
return Some(func_name);
tapishr marked this conversation as resolved.
Show resolved Hide resolved
}
}
11 changes: 11 additions & 0 deletions vibi-dpu/src/graph/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,17 @@ pub fn absolute_to_relative_path(abs_path: &str, review: &Review) -> Option<Stri
return Some(rel_path.to_str().expect("Unable to deserialze rel_path").to_string());
}

/// Extracts the payload of the first "```json" fenced block found in `json_str`.
///
/// Returns the text between the first "```json" marker and the next "```"
/// that follows it. The slice is returned untouched (no trimming), so any
/// newline right after the opening fence is part of the result.
///
/// If the opening fence is present but no closing "```" follows it (for
/// example, a response that was cut short or simply omitted the closing
/// fence), everything after the opening fence is returned so the caller can
/// still attempt to parse it.
///
/// Returns `None` when `json_str` contains no "```json" marker at all.
pub fn strip_json_prefix(json_str: &str) -> Option<&str> {
    const FENCE_OPEN: &str = "```json";
    const FENCE_CLOSE: &str = "```";

    // Everything after the first opening fence; bail out if there is none.
    let (_, after_open) = json_str.split_once(FENCE_OPEN)?;

    // Prefer the text up to the closing fence; fall back to the whole
    // remainder when the fence is never closed.
    let payload = match after_open.split_once(FENCE_CLOSE) {
        Some((inside, _)) => inside,
        None => after_open,
    };
    Some(payload)
}

// Generate a map of file extensions to languages or frameworks
fn get_extension_map() -> HashMap<&'static str, &'static str> {
let mut extension_map = HashMap::new();
Expand Down
Loading