Skip to content

Commit

Permalink
misc: small fix or general refactoring i did not bother commenting
Browse files Browse the repository at this point in the history
  • Loading branch information
evilsocket committed Nov 21, 2024
1 parent 382333b commit aa7909e
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 64 deletions.
124 changes: 67 additions & 57 deletions nerve-core/src/agent/generator/openai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,71 @@ impl OpenAIClient {

Ok(Self { model, client })
}

async fn get_tools_if_supported(&self, state: &SharedState) -> Vec<FunctionTool> {
let mut tools = vec![];

// if native tool calls are supported (and XML was not forced)
if state.lock().await.native_tools_support {
// for every namespace available to the model
for group in state.lock().await.get_namespaces() {
// for every action of the namespace
for action in &group.actions {
let mut required = vec![];
let mut properties = HashMap::new();

if let Some(example) = action.example_payload() {
required.push("payload".to_string());
properties.insert(
"payload".to_string(),
OpenAiToolFunctionParameterProperty {
the_type: "string".to_string(),
description: format!(
"The main function argument, use this as a template: {}",
example
),
},
);
}

if let Some(attrs) = action.example_attributes() {
for name in attrs.keys() {
required.push(name.to_string());
properties.insert(
name.to_string(),
OpenAiToolFunctionParameterProperty {
the_type: "string".to_string(),
description: name.to_string(),
},
);
}
}

let function = FunctionDefinition {
name: action.name().to_string(),
description: Some(action.description().to_string()),
parameters: Some(serde_json::json!(OpenAiToolFunctionParameters {
the_type: "object".to_string(),
required,
properties,
})),
};

tools.push(FunctionTool {
the_type: "function".to_string(),
function,
});
}
}

log::trace!("openai.tools={:?}", &tools);

// let j = serde_json::to_string_pretty(&tools).unwrap();
// log::info!("{j}");
}

tools
}
}

#[async_trait]
Expand All @@ -74,7 +139,7 @@ impl Client for OpenAIClient {
},
openai_api_rust::Message {
role: Role::User,
content: Some("Call the test function.".to_string()),
content: Some("Execute the test function.".to_string()),
tool_calls: None,
},
];
Expand Down Expand Up @@ -154,62 +219,7 @@ impl Client for OpenAIClient {
});
}

let mut tools = vec![];
if state.lock().await.native_tools_support {
for group in state.lock().await.get_namespaces() {
for action in &group.actions {
let mut required = vec![];
let mut properties = HashMap::new();

if let Some(example) = action.example_payload() {
required.push("payload".to_string());
properties.insert(
"payload".to_string(),
OpenAiToolFunctionParameterProperty {
the_type: "string".to_string(),
description: format!(
"The main function argument, use this as a template: {}",
example
),
},
);
}

if let Some(attrs) = action.example_attributes() {
for name in attrs.keys() {
required.push(name.to_string());
properties.insert(
name.to_string(),
OpenAiToolFunctionParameterProperty {
the_type: "string".to_string(),
description: name.to_string(),
},
);
}
}

let function = FunctionDefinition {
name: action.name().to_string(),
description: Some(action.description().to_string()),
parameters: Some(serde_json::json!(OpenAiToolFunctionParameters {
the_type: "object".to_string(),
required,
properties,
})),
};

tools.push(FunctionTool {
the_type: "function".to_string(),
function,
});
}
}

log::trace!("openai.tools={:?}", &tools);

// let j = serde_json::to_string_pretty(&tools).unwrap();
// log::info!("{j}");
}
let tools = self.get_tools_if_supported(&state).await;

let body = ChatBody {
model: self.model.to_string(),
Expand Down
22 changes: 15 additions & 7 deletions nerve-core/src/agent/namespaces/filesystem/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,13 +202,21 @@ impl Action for AppendToFile {
}
};

// parse the payload as a JSON object
let one_line_json = if let Ok(value) = serde_json::from_str::<serde_json::Value>(&payload) {
// reconvert to make sure it's on a single line
serde_json::to_string(&value).unwrap()
// get lowercase file extension from filepath
let extension = filepath.rsplit('.').next().unwrap_or("").to_lowercase();

let content_to_append = if extension == "json" || extension == "jsonl" {
// parse the payload as a JSON object
if let Ok(value) = serde_json::from_str::<serde_json::Value>(&payload) {
// reconvert to make sure it's on a single line
serde_json::to_string(&value).unwrap()
} else {
log::error!("can't parse payload as JSON: {}", payload);
serde_json::to_string(&InvalidJSON { data: payload }).unwrap()
}
} else {
log::error!("can't parse payload as JSON: {}", payload);
serde_json::to_string(&InvalidJSON { data: payload }).unwrap()
// add as it is
payload
};

// append the JSON to the file
Expand All @@ -217,7 +225,7 @@ impl Action for AppendToFile {
.create(true)
.open(&filepath)?;

writeln!(file, "{}", one_line_json)?;
writeln!(file, "{}", content_to_append)?;

Ok(None)
}
Expand Down

0 comments on commit aa7909e

Please sign in to comment.