Skip to content

Commit

Permalink
Feat/tools/gitlab (langgenius#10407)
Browse files Browse the repository at this point in the history
  • Loading branch information
wlrnet authored Nov 8, 2024
1 parent 0e8ab05 commit c9f785e
Show file tree
Hide file tree
Showing 7 changed files with 320 additions and 48 deletions.
7 changes: 3 additions & 4 deletions api/core/rag/extractor/word_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
class WordExtractor(BaseExtractor):
"""Load docx files.
Args:
file_path: Path to the file to load.
"""
Expand All @@ -51,9 +50,9 @@ def __init__(self, file_path: str, tenant_id: str, user_id: str):

self.web_path = self.file_path
# TODO: use a better way to handle the file
self.temp_file = tempfile.NamedTemporaryFile() # noqa: SIM115
self.temp_file.write(r.content)
self.file_path = self.temp_file.name
with tempfile.NamedTemporaryFile(delete=False) as self.temp_file:
self.temp_file.write(r.content)
self.file_path = self.temp_file.name
elif not os.path.isfile(self.file_path):
raise ValueError(f"File path {self.file_path} is not a valid file or url")

Expand Down
66 changes: 29 additions & 37 deletions api/core/tools/provider/builtin/gitlab/tools/gitlab_commits.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@ class GitlabCommitsTool(BuiltinTool):
def _invoke(
self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
project = tool_parameters.get("project", "")
branch = tool_parameters.get("branch", "")
repository = tool_parameters.get("repository", "")
employee = tool_parameters.get("employee", "")
start_time = tool_parameters.get("start_time", "")
end_time = tool_parameters.get("end_time", "")
change_type = tool_parameters.get("change_type", "all")

if not project and not repository:
return self.create_text_message("Either project or repository is required")
if not repository:
return self.create_text_message("Either repository is required")

if not start_time:
start_time = (datetime.utcnow() - timedelta(days=1)).isoformat()
Expand All @@ -37,22 +37,18 @@ def _invoke(
site_url = "https://gitlab.com"

# Get commit content
if repository:
result = self.fetch_commits(
site_url, access_token, repository, employee, start_time, end_time, change_type, is_repository=True
)
else:
result = self.fetch_commits(
site_url, access_token, project, employee, start_time, end_time, change_type, is_repository=False
)
result = self.fetch_commits(
site_url, access_token, repository, branch, employee, start_time, end_time, change_type, is_repository=True
)

return [self.create_json_message(item) for item in result]

def fetch_commits(
self,
site_url: str,
access_token: str,
identifier: str,
repository: str,
branch: str,
employee: str,
start_time: str,
end_time: str,
Expand All @@ -64,27 +60,14 @@ def fetch_commits(
results = []

try:
if is_repository:
# URL encode the repository path
encoded_identifier = urllib.parse.quote(identifier, safe="")
commits_url = f"{domain}/api/v4/projects/{encoded_identifier}/repository/commits"
else:
# Get all projects
url = f"{domain}/api/v4/projects"
response = requests.get(url, headers=headers)
response.raise_for_status()
projects = response.json()

filtered_projects = [p for p in projects if identifier == "*" or p["name"] == identifier]

for project in filtered_projects:
project_id = project["id"]
project_name = project["name"]
print(f"Project: {project_name}")

commits_url = f"{domain}/api/v4/projects/{project_id}/repository/commits"
# URL encode the repository path
encoded_repository = urllib.parse.quote(repository, safe="")
commits_url = f"{domain}/api/v4/projects/{encoded_repository}/repository/commits"

# Fetch commits for the repository
params = {"since": start_time, "until": end_time}
if branch:
params["ref_name"] = branch
if employee:
params["author"] = employee

Expand All @@ -96,10 +79,7 @@ def fetch_commits(
commit_sha = commit["id"]
author_name = commit["author_name"]

if is_repository:
diff_url = f"{domain}/api/v4/projects/{encoded_identifier}/repository/commits/{commit_sha}/diff"
else:
diff_url = f"{domain}/api/v4/projects/{project_id}/repository/commits/{commit_sha}/diff"
diff_url = f"{domain}/api/v4/projects/{encoded_repository}/repository/commits/{commit_sha}/diff"

diff_response = requests.get(diff_url, headers=headers)
diff_response.raise_for_status()
Expand All @@ -120,7 +100,14 @@ def fetch_commits(
if line.startswith("+") and not line.startswith("+++")
]
)
results.append({"commit_sha": commit_sha, "author_name": author_name, "diff": final_code})
results.append(
{
"diff_url": diff_url,
"commit_sha": commit_sha,
"author_name": author_name,
"diff": final_code,
}
)
else:
if total_changes > 1:
final_code = "".join(
Expand All @@ -134,7 +121,12 @@ def fetch_commits(
)
final_code_escaped = json.dumps(final_code)[1:-1] # Escape the final code
results.append(
{"commit_sha": commit_sha, "author_name": author_name, "diff": final_code_escaped}
{
"diff_url": diff_url,
"commit_sha": commit_sha,
"author_name": author_name,
"diff": final_code_escaped,
}
)
except requests.RequestException as e:
print(f"Error fetching data from GitLab: {e}")
Expand Down
14 changes: 7 additions & 7 deletions api/core/tools/provider/builtin/gitlab/tools/gitlab_commits.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ parameters:
form: llm
- name: repository
type: string
required: false
required: true
label:
en_US: repository
zh_Hans: 仓库路径
Expand All @@ -32,16 +32,16 @@ parameters:
zh_Hans: 仓库路径,以namespace/project_name的形式。
llm_description: Repository path for GitLab, like namespace/project_name.
form: llm
- name: project
- name: branch
type: string
required: false
label:
en_US: project
zh_Hans: 项目名
en_US: branch
zh_Hans: 分支名
human_description:
en_US: project
zh_Hans: 项目名
llm_description: project for GitLab
en_US: branch
zh_Hans: 分支名
llm_description: branch for GitLab
form: llm
- name: start_time
type: string
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import urllib.parse
from typing import Any, Union

import requests

from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool


class GitlabMergeRequestsTool(BuiltinTool):
def _invoke(
self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
repository = tool_parameters.get("repository", "")
branch = tool_parameters.get("branch", "")
start_time = tool_parameters.get("start_time", "")
end_time = tool_parameters.get("end_time", "")
state = tool_parameters.get("state", "opened") # Default to "opened"

if not repository:
return self.create_text_message("Repository is required")

access_token = self.runtime.credentials.get("access_tokens")
site_url = self.runtime.credentials.get("site_url")

if not access_token:
return self.create_text_message("Gitlab API Access Tokens is required.")
if not site_url:
site_url = "https://gitlab.com"

# Get merge requests
result = self.get_merge_requests(site_url, access_token, repository, branch, start_time, end_time, state)

return [self.create_json_message(item) for item in result]

def get_merge_requests(
self, site_url: str, access_token: str, repository: str, branch: str, start_time: str, end_time: str, state: str
) -> list[dict[str, Any]]:
domain = site_url
headers = {"PRIVATE-TOKEN": access_token}
results = []

try:
# URL encode the repository path
encoded_repository = urllib.parse.quote(repository, safe="")
merge_requests_url = f"{domain}/api/v4/projects/{encoded_repository}/merge_requests"
params = {"state": state}

# Add time filters if provided
if start_time:
params["created_after"] = start_time
if end_time:
params["created_before"] = end_time

response = requests.get(merge_requests_url, headers=headers, params=params)
response.raise_for_status()
merge_requests = response.json()

for mr in merge_requests:
# Filter by target branch
if branch and mr["target_branch"] != branch:
continue

results.append(
{
"id": mr["id"],
"title": mr["title"],
"author": mr["author"]["name"],
"web_url": mr["web_url"],
"target_branch": mr["target_branch"],
"created_at": mr["created_at"],
"state": mr["state"],
}
)
except requests.RequestException as e:
print(f"Error fetching merge requests from GitLab: {e}")

return results
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
identity:
name: gitlab_mergerequests
author: Leo.Wang
label:
en_US: GitLab Merge Requests
zh_Hans: GitLab 合并请求查询
description:
human:
en_US: A tool for query GitLab merge requests, Input should be a exists reposity or branch.
zh_Hans: 一个用于查询 GitLab 代码合并请求的工具,输入的内容应该是一个已存在的仓库名或者分支。
llm: A tool for query GitLab merge requests, Input should be a exists reposity or branch.
parameters:
- name: repository
type: string
required: false
label:
en_US: repository
zh_Hans: 仓库路径
human_description:
en_US: repository
zh_Hans: 仓库路径,以namespace/project_name的形式。
llm_description: Repository path for GitLab, like namespace/project_name.
form: llm
- name: branch
type: string
required: false
label:
en_US: branch
zh_Hans: 分支名
human_description:
en_US: branch
zh_Hans: 分支名
llm_description: branch for GitLab
form: llm
- name: start_time
type: string
required: false
label:
en_US: start_time
zh_Hans: 开始时间
human_description:
en_US: start_time
zh_Hans: 开始时间
llm_description: Start time for GitLab
form: llm
- name: end_time
type: string
required: false
label:
en_US: end_time
zh_Hans: 结束时间
human_description:
en_US: end_time
zh_Hans: 结束时间
llm_description: End time for GitLab
form: llm
- name: state
type: select
required: false
options:
- value: opened
label:
en_US: opened
zh_Hans: 打开
- value: closed
label:
en_US: closed
zh_Hans: 关闭
default: opened
label:
en_US: state
zh_Hans: 变更状态
human_description:
en_US: state
zh_Hans: 变更状态
llm_description: Merge request state type for GitLab
form: llm
Loading

0 comments on commit c9f785e

Please sign in to comment.