Add arxiv_search service function #148

Merged 5 commits on Apr 9, 2024

1 change: 1 addition & 0 deletions setup.py
@@ -29,6 +29,7 @@
"pymongo",
"pymysql",
"beautifulsoup4",
"feedparser",
]

doc_requires = [
2 changes: 2 additions & 0 deletions src/agentscope/service/__init__.py
@@ -17,6 +17,7 @@
from .sql_query.sqlite import query_sqlite
from .sql_query.mongodb import query_mongodb
from .web_search.search import bing_search, google_search
from .web_search.arxiv import arxiv_search
from .service_response import ServiceResponse
from .service_factory import ServiceFactory
from .retrieval.similarity import cos_sim
@@ -46,6 +47,7 @@ def get_help() -> None:
"write_json_file",
"bing_search",
"google_search",
"arxiv_search",
"query_mysql",
"query_sqlite",
"query_mongodb",
293 changes: 293 additions & 0 deletions src/agentscope/service/web_search/arxiv.py
@@ -0,0 +1,293 @@
# -*- coding: utf-8 -*-
"""Search papers via the arXiv API. This implementation refers to the
repository https://github.com/lukasschwab/arxiv.py, which is MIT licensed.
"""
import json
import re
import time
import urllib.request
from calendar import timegm
from datetime import datetime, timezone
from typing import List, Optional, Union

try:
    import feedparser
except ImportError:
    feedparser = None
from loguru import logger

from agentscope.service.service_response import (
    ServiceResponse,
    ServiceExecStatus,
)

ARXIV_SEARCH_URL = "http://export.arxiv.org/api/query?{parameters_str}"

LOGIC_OPERATORS = ["ANDNOT", "AND", "OR"]

SYMBOLS = ["(", ")"]

QUERY_PREFIX = ["all:", "ti:", "au:", "abs:", "co:", "jr:", "cat:", "rn:"]
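
# For illustration (hypothetical query values), a fully assembled request
# URL built from ARXIV_SEARCH_URL might look like:
#   http://export.arxiv.org/api/query?search_query=ti:%22Agentscope%22&max_results=1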


class _Result(dict):
    """The class for arXiv search results."""

    __getattr__ = dict.__getitem__
    __setattr__ = dict.__setitem__

    entry_id: str
    """A URL of the form `https://arxiv.org/abs/{id}`."""

    title: str
    """The title of the result."""

    updated: str
    """When the result was last updated."""

    published: str
    """When the result was published."""

    summary: str
    """The summary of the search result."""

    authors: List[str]
    """The authors of the search result."""

    pdf_url: Optional[str]
    """A URL for the result's PDF if present."""

    comment: Optional[str]
    """The authors' comment if present."""

    primary_category: Optional[str]
    """The result's primary arXiv category. See [arXiv: Category
    Taxonomy](https://arxiv.org/category_taxonomy)."""

    tags: List[str]
    """All of the result's tags. See [arXiv: Category
    Taxonomy](https://arxiv.org/category_taxonomy)."""

    journal_ref: Optional[str]
    """A journal reference if present."""

    doi: Optional[str]
    """A URL for the resolved DOI to an external resource if present."""

    def __init__(
        self,
        entry_id: str,
        title: str,
        updated: str,
        published: str,
        summary: str,
        authors: List[str],
        pdf_url: Optional[str] = None,
        comment: Optional[str] = None,
        primary_category: Optional[str] = None,
        tags: Optional[List[str]] = None,
        journal_ref: Optional[str] = None,
        doi: Optional[str] = None,
    ) -> None:
        """Initialize the arXiv search result."""
        self.entry_id = entry_id
        self.title = title
        self.updated = updated
        self.published = published
        self.summary = summary
        self.authors = authors
        self.pdf_url = pdf_url
        self.comment = comment
        self.primary_category = primary_category
        self.tags = tags
        self.journal_ref = journal_ref
        self.doi = doi

    def __str__(self) -> str:
        cleaned_dict = {}
        for key in self:
            if self[key] is not None:
                cleaned_dict[key] = self[key]
        return json.dumps(cleaned_dict, ensure_ascii=False)

    def __repr__(self) -> str:
        return self.__str__()


def _parse_pdf_url(links: List) -> Union[str, None]:
    """Parse the pdf url from the links."""
    for link in links:
        if link.get("title") == "pdf":
            return link.get("href")
    return None


def _parse_timestamp(timestamp: time.struct_time) -> str:
    """Parse the timestamp to a string."""
    dt = datetime.fromtimestamp(timegm(timestamp), tz=timezone.utc)
    return dt.strftime("%Y-%m-%d %H:%M:%S")


def _clean_arxiv_search_results(result: dict) -> dict:
    """Clean the arXiv search results and remove unnecessary information."""
    feed = result.feed

    # Basic information
    cleaned_dict = {
        "updated": _parse_timestamp(feed.updated_parsed),
        "opensearch_total_results": int(feed.opensearch_totalresults),
        "opensearch_start_index": int(feed.opensearch_startindex),
        "opensearch_itemsperpage": int(feed.opensearch_itemsperpage),
    }

    # Entries
    entries = []
    for entry in result.entries:
        title = "0"
        if hasattr(entry, "title"):
            title = entry.title
        else:
            logger.warning(
                "Result {} is missing title attribute; defaulting to '0'",
                entry.id,
            )

        tags = [tag.get("term") for tag in entry.tags]
        if len(tags) == 0:
            tags = None

        entry_dict = _Result(
            # Basic properties
            entry_id=entry.id,
            title=title,
            updated=_parse_timestamp(entry.updated_parsed),
            published=_parse_timestamp(entry.published_parsed),
            summary=entry.summary,
            authors=[author.name for author in entry.authors],
            # Optional properties
            pdf_url=_parse_pdf_url(entry.links),
            comment=entry.get("arxiv_comment"),
            primary_category=entry.arxiv_primary_category.get("term"),
            tags=tags,
            journal_ref=entry.get("arxiv_journal_ref"),
            doi=entry.get("arxiv_doi"),
        )

        entries.append(entry_dict)

    cleaned_dict["entries"] = entries

    return cleaned_dict


def _reformat_query(query: str) -> str:
    """Reformat the query string for arXiv search; refer to
    https://info.arxiv.org/help/api/user-manual.html."""
    delimiter_regex = (
        "("
        + "|".join(
            map(re.escape, LOGIC_OPERATORS + QUERY_PREFIX + SYMBOLS),
        )
        + ")"
    )

    parts = re.split(delimiter_regex, query)

    parts = [part.strip() for part in parts if part.strip()]

    for i, part in enumerate(parts):
        if part not in LOGIC_OPERATORS + QUERY_PREFIX + SYMBOLS:
            # URL-encode quotes and spaces, then wrap the search term in
            # encoded double quotes (%22) if it is not already wrapped
            part = part.replace('"', "%22").replace(" ", "+")

            if not part.startswith("%22"):
                part = f"%22{part}"
            if not part.endswith("%22"):
                part = f"{part}%22"
            parts[i] = part
        elif part in SYMBOLS:
            parts[i] = part.replace("(", "%28").replace(")", "%29")
        elif part in LOGIC_OPERATORS:
            parts[i] = f"+{part}+"

    refined_query = "".join(parts)

    return refined_query
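
# For illustration, a sketch of the transformation performed above
# (input value is made up):
#   _reformat_query('ti:"Deep Learning" AND au:LeCun')
#   -> 'ti:%22Deep+Learning%22+AND+au:%22LeCun%22'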


def arxiv_search(
    search_query: str,
    id_list: Optional[List[str]] = None,
    start: int = 0,
    max_results: Optional[int] = None,
) -> ServiceResponse:
    """Search arXiv papers by a given query string.

    Args:
        search_query (`str`):
            The query string, supporting the prefixes "all:", "ti:", "au:",
            "abs:", "co:", "jr:", "cat:", and "rn:", and the boolean
            operators "AND", "OR", and "ANDNOT". For example, to search for
            papers with title "Deep Learning" and author "LeCun", use the
            search_query 'ti:"Deep Learning" AND au:"LeCun"'.
        id_list (`Optional[List[str]]`, defaults to `None`):
            A list of arXiv IDs to search.
        start (`int`, defaults to `0`):
            The index of the first search result to return.
        max_results (`Optional[int]`, defaults to `None`):
            The maximum number of search results to return.

    Returns:
        `ServiceResponse`: A dictionary with two variables: `status` and
        `content`. The `status` variable is from the ServiceExecStatus enum,
        and `content` holds the cleaned search results or error information,
        depending on the `status` variable.
    """

    if feedparser is None:
        raise ImportError(
            "The `feedparser` module is not installed. Please install it by "
            "running `pip install feedparser`.",
        )

    # Construct the request URL
    search_query = _reformat_query(search_query)

    parameters = {"search_query": search_query}

    if id_list:
        parameters["id_list"] = ",".join(id_list)

    if start > 0:
        parameters["start"] = str(start)

    if max_results:
        parameters["max_results"] = str(max_results)

    parameters_str = "&".join([f"{k}={v}" for k, v in parameters.items()])

    url = ARXIV_SEARCH_URL.format(parameters_str=parameters_str)

    try:
        logger.debug(f"Searching arXiv by url: {url}")

        with urllib.request.urlopen(url) as data:
            # Parse the results with feedparser
            feedparser_dict = feedparser.parse(data.read().decode("utf-8"))

            # Remove unnecessary information
            results = _clean_arxiv_search_results(feedparser_dict)

            if data.code == 200:
                # Return the search results
                return ServiceResponse(
                    status=ServiceExecStatus.SUCCESS,
                    content=results,
                )
            else:
                return ServiceResponse(
                    status=ServiceExecStatus.ERROR,
                    content=f"Error: {data.code}, {data}",
                )
    except Exception as e:
        return ServiceResponse(
            status=ServiceExecStatus.ERROR,
            content=f"Error: {e}",
        )
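
Taken together, the new service can be exercised as follows. This is a
minimal usage sketch, assuming agentscope is installed with the new
feedparser dependency and export.arxiv.org is reachable; the query values
are illustrative:

    from agentscope.service import arxiv_search
    from agentscope.service.service_status import ServiceExecStatus

    res = arxiv_search(
        search_query='ti:"Deep Learning" AND au:"LeCun"',
        max_results=3,
    )
    if res.status == ServiceExecStatus.SUCCESS:
        for entry in res.content["entries"]:
            # Each entry is a dict-like _Result with title, authors,
            # pdf_url, and similar fields.
            print(entry["title"], entry["pdf_url"])
    else:
        print("Search failed:", res.content)
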
29 changes: 28 additions & 1 deletion tests/web_search_test.py
@@ -3,9 +3,10 @@
import unittest
from unittest.mock import Mock, patch, MagicMock

from agentscope.service import ServiceResponse
from agentscope.service import ServiceResponse, arxiv_search
from agentscope.service import bing_search, google_search
from agentscope.service.service_status import ServiceExecStatus
from agentscope.service.web_search.arxiv import _reformat_query


class TestWebSearches(unittest.TestCase):
@@ -124,6 +125,32 @@ def test_search_google(self, mock_get: MagicMock) -> None:
        )
        self.assertEqual(results, expected_result)

    def test_arxiv_search(self) -> None:
        """Test arxiv search."""
        res = arxiv_search(
            search_query="ti:Agentscope",
            id_list=["2402.14034"],
            max_results=1,
        )
        self.assertEqual(
            res.content["entries"][0]["title"],
            "AgentScope: A Flexible yet Robust Multi-Agent Platform",
        )

    def test_arxiv_query_format(self) -> None:
        """Test arxiv query format."""
        res = _reformat_query(
            'ti: "Deep Learning" ANDau:LeCun OR (ti: machine learning '
            "ANDNOT au:John Doe)",
        )

        ground_truth = (
            "ti:%22Deep+Learning%22+AND+au:%22LeCun%22+OR+%28ti:"
            "%22machine+learning%22+ANDNOT+au:%22John+Doe%22%29"
        )

        self.assertEqual(ground_truth, res)


# This allows the tests to be run from the command line
if __name__ == "__main__":
    unittest.main()