diff --git a/pull_request_codecommit/git/remote.py b/pull_request_codecommit/git/remote.py index 5a181d4..2c2ba09 100644 --- a/pull_request_codecommit/git/remote.py +++ b/pull_request_codecommit/git/remote.py @@ -1,5 +1,6 @@ import re from typing import Optional +from enum import Enum class Remote: @@ -7,11 +8,18 @@ class Remote: Understands remote CodeCommit URLs """ + class Protocol(Enum): + SSH = 1 + HTTPS = 2 + def __init__(self, url: str): self.__url: str = url self.__region: Optional[str] = "" self.__profile: Optional[str] = "" self.__name: str = "" + self.__protocol: Remote.Protocol = ( + Remote.Protocol.SSH if url.startswith("ssh://") else Remote.Protocol.HTTPS + ) def __regex(self, pattern: str, index: int = 1) -> Optional[str]: match = re.search(pattern, self.__url) @@ -19,7 +27,12 @@ def __regex(self, pattern: str, index: int = 1) -> Optional[str]: @property def supported(self) -> bool: - return self.__url.startswith("codecommit:") and self.name != "" + return ( + self.__protocol == Remote.Protocol.HTTPS + and self.__url.startswith("codecommit:") + or self.__protocol == Remote.Protocol.SSH + and self.__url.startswith("ssh://") + ) and self.name != "" @property def url(self) -> str: @@ -28,7 +41,10 @@ def url(self) -> str: @property def region(self) -> Optional[str]: if not self.__region: - self.__region = self.__regex(r"^codecommit::(.*)://") + if self.__protocol == Remote.Protocol.HTTPS: + self.__region = self.__regex(r"^codecommit::(.*)://") + elif self.__protocol == Remote.Protocol.SSH: + self.__region = self.__regex(r"ssh://git-codecommit.(.*).amazonaws.com") return self.__region @@ -42,7 +58,10 @@ def profile(self) -> Optional[str]: @property def name(self) -> str: if not self.__name: - name = self.__regex(r"(\/\/|.*@)(.*)$", 2) + if self.__protocol == Remote.Protocol.HTTPS: + name = self.__regex(r"(\/\/|.*@)(.*)$", 2) + elif self.__protocol == Remote.Protocol.SSH: + name = self.__regex(r"/([^/]*)$", 1) if name: self.__name = name