Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 49 additions & 10 deletions src/intelstream/discord/cogs/source_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@


_TWITTER_USERNAME_RE = re.compile(r"^[A-Za-z0-9_]{1,15}$")
_ARXIV_CATEGORY_RE = re.compile(r"^[A-Za-z0-9.-]+$")


def _is_valid_twitter_username(username: str) -> bool:
Expand All @@ -36,6 +37,45 @@ class InvalidSourceURLError(ValueError):
pass


def _parse_arxiv_identifier(value: str) -> str:
candidate = value.strip()
if not candidate:
raise InvalidSourceURLError("Arxiv category cannot be empty.")

if candidate.startswith(("arxiv.org/", "www.arxiv.org/")):
candidate = f"https://{candidate}"

parsed = urlparse(candidate)
identifier = candidate

if parsed.scheme or parsed.netloc:
host = parsed.netloc.lower()
if host not in ("arxiv.org", "www.arxiv.org"):
raise InvalidSourceURLError(f"Invalid Arxiv URL: {value}. Expected arxiv.org domain.")

path = parsed.path.rstrip("/")
if path.startswith("/list/"):
parts = [part for part in path.split("/") if part]
if len(parts) < 2:
raise InvalidSourceURLError(
f"Invalid Arxiv URL: {value}. Could not extract category."
)
identifier = parts[1]
elif path.startswith("/rss/"):
identifier = path.split("/rss/", 1)[1].strip("/")
else:
raise InvalidSourceURLError(
f"Invalid Arxiv URL: {value}. Expected an arxiv.org list or RSS URL."
)

if not identifier or not _ARXIV_CATEGORY_RE.fullmatch(identifier):
raise InvalidSourceURLError(
f"Invalid Arxiv category: {value}. Expected format like cs.AI or stat.ML."
)

return identifier


def parse_source_identifier(source_type: SourceType, url: str) -> tuple[str, str | None]:
parsed = urlparse(url)

Expand Down Expand Up @@ -86,9 +126,7 @@ def parse_source_identifier(source_type: SourceType, url: str) -> tuple[str, str
return identifier, url

elif source_type == SourceType.ARXIV:
identifier = url.strip()
if not identifier:
raise InvalidSourceURLError("Arxiv category cannot be empty.")
identifier = _parse_arxiv_identifier(url)
feed_url = f"https://arxiv.org/rss/{identifier}"
return identifier, feed_url

Expand Down Expand Up @@ -248,7 +286,14 @@ async def source_add(
)
return

safe, error_msg = is_safe_url(url)
try:
identifier, feed_url = parse_source_identifier(stype, url)
except InvalidSourceURLError as e:
await interaction.followup.send(str(e), ephemeral=True)
return

validation_url = feed_url if stype == SourceType.ARXIV and feed_url else url
safe, error_msg = is_safe_url(validation_url)
if not safe:
await interaction.followup.send(f"URL not allowed: {error_msg}", ephemeral=True)
return
Expand Down Expand Up @@ -298,12 +343,6 @@ async def source_add(
)
return

try:
identifier, feed_url = parse_source_identifier(stype, url)
except InvalidSourceURLError as e:
await interaction.followup.send(str(e), ephemeral=True)
return

existing = await self.bot.repository.get_source_by_identifier(identifier)
if existing:
await interaction.followup.send(
Expand Down
103 changes: 103 additions & 0 deletions tests/test_discord/test_source_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,39 @@ def test_parse_arxiv_empty(self):
with pytest.raises(InvalidSourceURLError, match="cannot be empty"):
parse_source_identifier(SourceType.ARXIV, " ")

def test_parse_arxiv_category(self):
identifier, feed_url = parse_source_identifier(SourceType.ARXIV, "cs.AI")
assert identifier == "cs.AI"
assert feed_url == "https://arxiv.org/rss/cs.AI"

def test_parse_arxiv_list_url(self):
identifier, feed_url = parse_source_identifier(
SourceType.ARXIV,
"https://arxiv.org/list/cs.AI/",
)
assert identifier == "cs.AI"
assert feed_url == "https://arxiv.org/rss/cs.AI"

def test_parse_arxiv_recent_list_url(self):
identifier, feed_url = parse_source_identifier(
SourceType.ARXIV,
"https://arxiv.org/list/cs.AI/recent",
)
assert identifier == "cs.AI"
assert feed_url == "https://arxiv.org/rss/cs.AI"

def test_parse_arxiv_rss_url(self):
identifier, feed_url = parse_source_identifier(
SourceType.ARXIV,
"https://arxiv.org/rss/stat.ML",
)
assert identifier == "stat.ML"
assert feed_url == "https://arxiv.org/rss/stat.ML"

def test_parse_arxiv_wrong_domain(self):
with pytest.raises(InvalidSourceURLError, match=r"Expected arxiv\.org domain"):
parse_source_identifier(SourceType.ARXIV, "https://example.com/list/cs.AI/")

def test_parse_page_no_host(self):
with pytest.raises(InvalidSourceURLError, match="No host found"):
parse_source_identifier(SourceType.PAGE, "not-a-url")
Expand Down Expand Up @@ -366,6 +399,76 @@ async def test_add_source_with_summarize_true(self, source_management, mock_bot)
call_kwargs = mock_bot.repository.add_source.call_args.kwargs
assert call_kwargs["skip_summary"] is False

async def test_add_arxiv_source_with_category_identifier(self, source_management, mock_bot):
interaction = MagicMock(spec=discord.Interaction)
interaction.response = MagicMock()
interaction.response.defer = AsyncMock()
interaction.followup = MagicMock()
interaction.followup.send = AsyncMock()
interaction.user = MagicMock()
interaction.user.id = 123
interaction.guild_id = 456
interaction.channel_id = 789

source_type_choice = MagicMock()
source_type_choice.value = "arxiv"
source_type_choice.name = "Arxiv"

mock_bot.repository.get_source_by_identifier = AsyncMock(return_value=None)
mock_bot.repository.get_source_by_name = AsyncMock(return_value=None)

mock_source = MagicMock()
mock_source.id = "new-source-id"
mock_bot.repository.add_source = AsyncMock(return_value=mock_source)

await source_management.source_add.callback(
source_management,
interaction,
source_type=source_type_choice,
name="arxiv_cs.AI",
url="cs.AI",
)

mock_bot.repository.add_source.assert_called_once()
call_kwargs = mock_bot.repository.add_source.call_args.kwargs
assert call_kwargs["identifier"] == "cs.AI"
assert call_kwargs["feed_url"] == "https://arxiv.org/rss/cs.AI"

async def test_add_arxiv_source_with_list_url(self, source_management, mock_bot):
interaction = MagicMock(spec=discord.Interaction)
interaction.response = MagicMock()
interaction.response.defer = AsyncMock()
interaction.followup = MagicMock()
interaction.followup.send = AsyncMock()
interaction.user = MagicMock()
interaction.user.id = 123
interaction.guild_id = 456
interaction.channel_id = 789

source_type_choice = MagicMock()
source_type_choice.value = "arxiv"
source_type_choice.name = "Arxiv"

mock_bot.repository.get_source_by_identifier = AsyncMock(return_value=None)
mock_bot.repository.get_source_by_name = AsyncMock(return_value=None)

mock_source = MagicMock()
mock_source.id = "new-source-id"
mock_bot.repository.add_source = AsyncMock(return_value=mock_source)

await source_management.source_add.callback(
source_management,
interaction,
source_type=source_type_choice,
name="arxiv_cs.AI",
url="https://arxiv.org/list/cs.AI/",
)

mock_bot.repository.add_source.assert_called_once()
call_kwargs = mock_bot.repository.add_source.call_args.kwargs
assert call_kwargs["identifier"] == "cs.AI"
assert call_kwargs["feed_url"] == "https://arxiv.org/rss/cs.AI"


class TestSourceManagementList:
async def test_list_sources_empty(self, source_management, mock_bot):
Expand Down
Loading