Code optimization to improve clarity and maintainability.

张建平 2025-02-27 11:07:34 +08:00
parent 1988a400c9
commit de84796560
67 changed files with 5417 additions and 2190 deletions

.gitignore

@ -13,4 +13,5 @@ docs/.cache/
.qodo
db.sqlite3
#pr_agent/
static/admin/
static/admin/
config.local.ini


@ -20,9 +20,9 @@ pygithub = "*"
python-gitlab = "*"
retry = "*"
fastapi = "*"
psycopg2-binary = "*"
[dev-packages]
[requires]
python_version = "3.12"
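
psycopg2-binary is now pinned here and added to the lockfile below, which points at a move from the SQLite database ignored above to PostgreSQL. A minimal sketch of the Django settings such a switch typically requires; the database name, credentials, and host are placeholders, not values taken from this commit:

# settings.py (sketch): Django's postgresql backend talks to the server via psycopg2.
DATABASES = {
    "default": {
        "ENGINE": "django.db.backends.postgresql",
        "NAME": "pr_agent",        # placeholder database name
        "USER": "pr_agent",        # placeholder role
        "PASSWORD": "change-me",   # placeholder; typically read from config.local.ini or the environment
        "HOST": "127.0.0.1",
        "PORT": "5432",
    }
}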

Pipfile.lock

@ -1,7 +1,7 @@
{
"_meta": {
"hash": {
"sha256": "420206f7faa4351eabc368a83deae9b7ed9e50b0975ac63a46d6367e9920848b"
"sha256": "497c1ff8497659883faf8dcca407665df1b3a37f67720f64b139f9dec8202892"
},
"pipfile-spec": 6,
"requires": {
@ -169,20 +169,19 @@
},
"boto3": {
"hashes": [
"sha256:01015b38017876d79efd7273f35d9a4adfba505237159621365bed21b9b65eca",
"sha256:03bd8c93b226f07d944fd6b022e11a307bff94ab6a21d51675d7e3ea81ee8424"
"sha256:e58136d52d79425ce26c3c1578bf94d4b2e91ead55fed9f6950406ee9713e6af"
],
"index": "pip_conf_index_global",
"markers": "python_version >= '3.8'",
"version": "==1.37.0"
"version": "==1.37.2"
},
"botocore": {
"hashes": [
"sha256:b129d091a8360b4152ab65327186bf4e250de827c4a9b7ddf40a72b1acf1f3c1",
"sha256:d01661f38c0edac87424344cdf4169f3ab9bc1bf1b677c8b230d025eb66c54a3"
"sha256:3f460f3c32cd6d747d5897a9cbde011bf1715abc7bf0a6ea6fdb0b812df63287",
"sha256:5f59b966f3cd0c8055ef6f7c2600f7db5f8218071d992e5f95da3f9156d4370f"
],
"markers": "python_version >= '3.8'",
"version": "==1.37.0"
"version": "==1.37.2"
},
"certifi": {
"hashes": [
@ -460,12 +459,12 @@
},
"django-import-export": {
"hashes": [
"sha256:317842a64233025a277040129fb6792fc48fd39622c185b70bf8c18c393d708f",
"sha256:ecb4e6cdb4790d69bce261f9cca1007ca19cb431bb5a950ba907898245c8817b"
"sha256:5514d09636e84e823a42cd5e79292f70f20d6d2feed117a145f5b64a5b44f168",
"sha256:bd3fe0aa15a2bce9de4be1a2f882e2c4539fdbfdfa16f2052c98dd7aec0f085c"
],
"index": "pip_conf_index_global",
"markers": "python_version >= '3.9'",
"version": "==4.3.6"
"version": "==4.3.7"
},
"django-simpleui": {
"hashes": [
@ -794,12 +793,12 @@
},
"litellm": {
"hashes": [
"sha256:02df5865f98ea9734a4d27ac7c33aad9a45c4015403d5c0797d3292ade3c5cb5",
"sha256:d241436ac0edf64ec57fb5686f8d84a25998a7e52213d9063adf87df8432701f"
"sha256:eaab989c090ccc094b41c3fdf27d1df7f6fb25e091ab0ce48e0f3079f1e51ff5",
"sha256:ff9137c008cdb421db32defb1fbd1ed546a95167de6d276c61b664582ed4ff60"
],
"index": "pip_conf_index_global",
"markers": "python_version not in '2.7, 3.0, 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7' and python_version >= '3.8'",
"version": "==1.61.16"
"version": "==1.61.17"
},
"loguru": {
"hashes": [
@ -1196,6 +1195,81 @@
"markers": "python_version >= '3.6'",
"version": "==7.0.0"
},
"psycopg2-binary": {
"hashes": [
"sha256:04392983d0bb89a8717772a193cfaac58871321e3ec69514e1c4e0d4957b5aff",
"sha256:056470c3dc57904bbf63d6f534988bafc4e970ffd50f6271fc4ee7daad9498a5",
"sha256:0ea8e3d0ae83564f2fc554955d327fa081d065c8ca5cc6d2abb643e2c9c1200f",
"sha256:155e69561d54d02b3c3209545fb08938e27889ff5a10c19de8d23eb5a41be8a5",
"sha256:18c5ee682b9c6dd3696dad6e54cc7ff3a1a9020df6a5c0f861ef8bfd338c3ca0",
"sha256:19721ac03892001ee8fdd11507e6a2e01f4e37014def96379411ca99d78aeb2c",
"sha256:1a6784f0ce3fec4edc64e985865c17778514325074adf5ad8f80636cd029ef7c",
"sha256:2286791ececda3a723d1910441c793be44625d86d1a4e79942751197f4d30341",
"sha256:230eeae2d71594103cd5b93fd29d1ace6420d0b86f4778739cb1a5a32f607d1f",
"sha256:245159e7ab20a71d989da00f280ca57da7641fa2cdcf71749c193cea540a74f7",
"sha256:26540d4a9a4e2b096f1ff9cce51253d0504dca5a85872c7f7be23be5a53eb18d",
"sha256:270934a475a0e4b6925b5f804e3809dd5f90f8613621d062848dd82f9cd62007",
"sha256:27422aa5f11fbcd9b18da48373eb67081243662f9b46e6fd07c3eb46e4535142",
"sha256:2ad26b467a405c798aaa1458ba09d7e2b6e5f96b1ce0ac15d82fd9f95dc38a92",
"sha256:2b3d2491d4d78b6b14f76881905c7a8a8abcf974aad4a8a0b065273a0ed7a2cb",
"sha256:2ce3e21dc3437b1d960521eca599d57408a695a0d3c26797ea0f72e834c7ffe5",
"sha256:30e34c4e97964805f715206c7b789d54a78b70f3ff19fbe590104b71c45600e5",
"sha256:3216ccf953b3f267691c90c6fe742e45d890d8272326b4a8b20850a03d05b7b8",
"sha256:32581b3020c72d7a421009ee1c6bf4a131ef5f0a968fab2e2de0c9d2bb4577f1",
"sha256:35958ec9e46432d9076286dda67942ed6d968b9c3a6a2fd62b48939d1d78bf68",
"sha256:3abb691ff9e57d4a93355f60d4f4c1dd2d68326c968e7db17ea96df3c023ef73",
"sha256:3c18f74eb4386bf35e92ab2354a12c17e5eb4d9798e4c0ad3a00783eae7cd9f1",
"sha256:3c4745a90b78e51d9ba06e2088a2fe0c693ae19cc8cb051ccda44e8df8a6eb53",
"sha256:3c4ded1a24b20021ebe677b7b08ad10bf09aac197d6943bfe6fec70ac4e4690d",
"sha256:3e9c76f0ac6f92ecfc79516a8034a544926430f7b080ec5a0537bca389ee0906",
"sha256:48b338f08d93e7be4ab2b5f1dbe69dc5e9ef07170fe1f86514422076d9c010d0",
"sha256:4b3df0e6990aa98acda57d983942eff13d824135fe2250e6522edaa782a06de2",
"sha256:512d29bb12608891e349af6a0cccedce51677725a921c07dba6342beaf576f9a",
"sha256:5a507320c58903967ef7384355a4da7ff3f28132d679aeb23572753cbf2ec10b",
"sha256:5c370b1e4975df846b0277b4deba86419ca77dbc25047f535b0bb03d1a544d44",
"sha256:6b269105e59ac96aba877c1707c600ae55711d9dcd3fc4b5012e4af68e30c648",
"sha256:6d4fa1079cab9018f4d0bd2db307beaa612b0d13ba73b5c6304b9fe2fb441ff7",
"sha256:6dc08420625b5a20b53551c50deae6e231e6371194fa0651dbe0fb206452ae1f",
"sha256:73aa0e31fa4bb82578f3a6c74a73c273367727de397a7a0f07bd83cbea696baa",
"sha256:7559bce4b505762d737172556a4e6ea8a9998ecac1e39b5233465093e8cee697",
"sha256:79625966e176dc97ddabc142351e0409e28acf4660b88d1cf6adb876d20c490d",
"sha256:7a813c8bdbaaaab1f078014b9b0b13f5de757e2b5d9be6403639b298a04d218b",
"sha256:7b2c956c028ea5de47ff3a8d6b3cc3330ab45cf0b7c3da35a2d6ff8420896526",
"sha256:7f4152f8f76d2023aac16285576a9ecd2b11a9895373a1f10fd9db54b3ff06b4",
"sha256:7f5d859928e635fa3ce3477704acee0f667b3a3d3e4bb109f2b18d4005f38287",
"sha256:851485a42dbb0bdc1edcdabdb8557c09c9655dfa2ca0460ff210522e073e319e",
"sha256:8608c078134f0b3cbd9f89b34bd60a943b23fd33cc5f065e8d5f840061bd0673",
"sha256:880845dfe1f85d9d5f7c412efea7a08946a46894537e4e5d091732eb1d34d9a0",
"sha256:8aabf1c1a04584c168984ac678a668094d831f152859d06e055288fa515e4d30",
"sha256:8aecc5e80c63f7459a1a2ab2c64df952051df196294d9f739933a9f6687e86b3",
"sha256:8cd9b4f2cfab88ed4a9106192de509464b75a906462fb846b936eabe45c2063e",
"sha256:8de718c0e1c4b982a54b41779667242bc630b2197948405b7bd8ce16bcecac92",
"sha256:9440fa522a79356aaa482aa4ba500b65f28e5d0e63b801abf6aa152a29bd842a",
"sha256:b5f86c56eeb91dc3135b3fd8a95dc7ae14c538a2f3ad77a19645cf55bab1799c",
"sha256:b73d6d7f0ccdad7bc43e6d34273f70d587ef62f824d7261c4ae9b8b1b6af90e8",
"sha256:bb89f0a835bcfc1d42ccd5f41f04870c1b936d8507c6df12b7737febc40f0909",
"sha256:c3cc28a6fd5a4a26224007712e79b81dbaee2ffb90ff406256158ec4d7b52b47",
"sha256:ce5ab4bf46a211a8e924d307c1b1fcda82368586a19d0a24f8ae166f5c784864",
"sha256:d00924255d7fc916ef66e4bf22f354a940c67179ad3fd7067d7a0a9c84d2fbfc",
"sha256:d7cd730dfa7c36dbe8724426bf5612798734bff2d3c3857f36f2733f5bfc7c00",
"sha256:e217ce4d37667df0bc1c397fdcd8de5e81018ef305aed9415c3b093faaeb10fb",
"sha256:e3923c1d9870c49a2d44f795df0c889a22380d36ef92440ff618ec315757e539",
"sha256:e5720a5d25e3b99cd0dc5c8a440570469ff82659bb09431c1439b92caf184d3b",
"sha256:e8b58f0a96e7a1e341fc894f62c1177a7c83febebb5ff9123b579418fdc8a481",
"sha256:e984839e75e0b60cfe75e351db53d6db750b00de45644c5d1f7ee5d1f34a1ce5",
"sha256:eb09aa7f9cecb45027683bb55aebaaf45a0df8bf6de68801a6afdc7947bb09d4",
"sha256:ec8a77f521a17506a24a5f626cb2aee7850f9b69a0afe704586f63a464f3cd64",
"sha256:ecced182e935529727401b24d76634a357c71c9275b356efafd8a2a91ec07392",
"sha256:ee0e8c683a7ff25d23b55b11161c2663d4b099770f6085ff0a20d4505778d6b4",
"sha256:f0c2d907a1e102526dd2986df638343388b94c33860ff3bbe1384130828714b1",
"sha256:f758ed67cab30b9a8d2833609513ce4d3bd027641673d4ebc9c067e4d208eec1",
"sha256:f8157bed2f51db683f31306aa497311b560f2265998122abe1dce6428bd86567",
"sha256:ffe8ed017e4ed70f68b7b371d84b7d4a790368db9203dfc2d222febd3a9c8863"
],
"index": "pip_conf_index_global",
"markers": "python_version >= '3.8'",
"version": "==2.9.10"
},
"py": {
"hashes": [
"sha256:51c75c4126074b472f746a24399ad32f6053d1b34b68d2fa41e558e6f4a98719",
@ -1766,11 +1840,11 @@
},
"s3transfer": {
"hashes": [
"sha256:3b39185cb72f5acc77db1a58b6e25b977f28d20496b6e58d6813d75f464d632f",
"sha256:be6ecb39fadd986ef1701097771f87e4d2f821f27f6071c872143884d2950fbc"
"sha256:ca855bdeb885174b5ffa95b9913622459d4ad8e331fc98eb01e6d5eb6a30655d",
"sha256:edae4977e3a122445660c7c114bba949f9d191bae3b34a096f18a1c8c354527a"
],
"markers": "python_version >= '3.8'",
"version": "==0.11.2"
"version": "==0.11.3"
},
"simplepro": {
"hashes": [


@ -34,7 +34,13 @@ class GitConfigAdmin(AjaxAdmin):
class ProjectConfigAdmin(AjaxAdmin):
"""Admin配置"""
list_display = ["project_id", "project_name", "project_secret", "commands", "is_enable"]
list_display = [
"project_id",
"project_name",
"project_secret",
"commands",
"is_enable",
]
readonly_fields = ["create_by", "delete_at", "detail"]
top_html = '<el-alert title="可配置多个项目!" type="success"></el-alert>'


@ -16,4 +16,3 @@ class Command(BaseCommand):
print("初始化AI配置已创建")
else:
print("初始化AI配置已存在")


@ -44,9 +44,7 @@ class GitConfig(BaseModel):
null=True, blank=True, max_length=16, verbose_name="Git名称"
)
git_type = fields.RadioField(
choices=constant.GIT_TYPE,
default=0,
verbose_name="Git类型"
choices=constant.GIT_TYPE, default=0, verbose_name="Git类型"
)
git_url = fields.CharField(
null=True, blank=True, max_length=128, verbose_name="Git地址"
@ -67,6 +65,7 @@ class ProjectConfig(BaseModel):
"""
项目配置表
"""
git_config = fields.ForeignKey(
GitConfig,
null=True,
@ -89,10 +88,7 @@ class ProjectConfig(BaseModel):
max_length=256,
verbose_name="默认命令",
)
is_enable = fields.SwitchField(
default=True,
verbose_name="是否启用"
)
is_enable = fields.SwitchField(default=True, verbose_name="是否启用")
class Meta:
verbose_name = "项目配置"
@ -106,6 +102,7 @@ class ProjectHistory(BaseModel):
"""
项目历史表
"""
project = fields.ForeignKey(
ProjectConfig,
null=True,
@ -128,9 +125,7 @@ class ProjectHistory(BaseModel):
mr_title = fields.CharField(
null=True, blank=True, max_length=256, verbose_name="MR标题"
)
source_data = models.JSONField(
null=True, blank=True, verbose_name="源数据"
)
source_data = models.JSONField(null=True, blank=True, verbose_name="源数据")
class Meta:
verbose_name = "项目历史"
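
For reference, the rows this history table accumulates (written by save_pr_agent_log further down) can be inspected with the ordinary Django ORM. A minimal sketch, with an assumed import path and illustrative field choices:

# Sketch: listing recent merge-request history rows; the module path is assumed.
from webhook.models import ProjectHistory   # hypothetical app/module name

recent = ProjectHistory.objects.order_by("-id")[:10]   # default integer primary key assumed
for h in recent:
    print(h.project_id, h.mr_title, h.mr_url)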


@ -12,12 +12,7 @@ from utils import constant
def load_project_config(
git_url,
access_token,
project_secret,
openai_api_base,
openai_key,
llm_model
git_url, access_token, project_secret, openai_api_base, openai_key, llm_model
):
"""
加载项目配置
@ -36,12 +31,11 @@ def load_project_config(
"secret": project_secret,
"openai_api_base": openai_api_base,
"openai_key": openai_key,
"llm_model": llm_model
"llm_model": llm_model,
}
class WebHookView(View):
@staticmethod
def select_git_provider(git_type):
"""
@ -82,7 +76,9 @@ class WebHookView(View):
project_config = provider.get_project_config(project_id=project_id)
# Token 校验
provider.check_secret(request_headers=headers, project_secret=project_config.get("project_secret"))
provider.check_secret(
request_headers=headers, project_secret=project_config.get("project_secret")
)
provider.get_merge_request(
request_data=json_data,
@ -91,11 +87,13 @@ class WebHookView(View):
api_base=project_config.get("api_base"),
api_key=project_config.get("api_key"),
llm_model=project_config.get("llm_model"),
project_commands=project_config.get("commands")
project_commands=project_config.get("commands"),
)
# 记录请求日志: 目前仅记录合并日志
if json_data.get('object_kind') == 'merge_request':
provider.save_pr_agent_log(request_data=json_data, project_id=project_config.get("project_id"))
provider.save_pr_agent_log(
request_data=json_data, project_id=project_config.get("project_id")
)
return JsonResponse(status=200, data={"status": "ignored"})
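
check_secret itself sits outside this hunk; a minimal sketch of what the GitLab variant typically does, assuming the stored secret is compared against the token GitLab sends in the X-Gitlab-Token header (the exception type is illustrative):

import hmac

from django.core.exceptions import PermissionDenied


def check_secret(request_headers, project_secret):
    # GitLab repeats the configured webhook secret verbatim in this header.
    received = request_headers.get("X-Gitlab-Token", "")
    # Constant-time comparison avoids leaking the secret through timing differences.
    if not hmac.compare_digest(received, project_secret or ""):
        raise PermissionDenied("webhook secret mismatch")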


@ -1,8 +1,4 @@
GIT_TYPE = (
(0, "gitlab"),
(1, "github"),
(2, "gitea")
)
GIT_TYPE = ((0, "gitlab"), (1, "github"), (2, "gitea"))
DEFAULT_COMMANDS = (
("/review", "/review"),
@ -10,11 +6,7 @@ DEFAULT_COMMANDS = (
("/improve_code", "/improve_code"),
)
UA_TYPE = {
"GitLab": "gitlab",
"GitHub": "github",
"Go-http-client": "gitea"
}
UA_TYPE = {"GitLab": "gitlab", "GitHub": "github", "Go-http-client": "gitea"}
def get_git_type_from_ua(ua_value):


@ -16,14 +16,14 @@ class GitProvider(ABC):
@abstractmethod
def get_merge_request(
self,
request_data,
git_url,
access_token,
api_base,
api_key,
llm_model,
project_commands
self,
request_data,
git_url,
access_token,
api_base,
api_key,
llm_model,
project_commands,
):
pass
@ -33,7 +33,6 @@ class GitProvider(ABC):
class GitLabProvider(GitProvider):
@staticmethod
def check_secret(request_headers, project_secret):
"""
@ -79,18 +78,18 @@ class GitLabProvider(GitProvider):
"access_token": git_config.access_token,
"project_secret": project_config.project_secret,
"commands": project_config.commands.split(","),
"project_id": project_config.id
"project_id": project_config.id,
}
def get_merge_request(
self,
request_data,
git_url,
access_token,
api_base,
api_key,
llm_model,
project_commands,
self,
request_data,
git_url,
access_token,
api_base,
api_key,
llm_model,
project_commands,
):
"""
实现GitLab Merge Request获取逻辑
@ -124,7 +123,10 @@ class GitLabProvider(GitProvider):
self.run_command(mr_url, project_commands)
# 数据库留存
return JsonResponse(status=200, data={"status": "review started"})
return JsonResponse(status=400, data={"error": "Merge request URL not found or action not open"})
return JsonResponse(
status=400,
data={"error": "Merge request URL not found or action not open"},
)
@staticmethod
def save_pr_agent_log(request_data, project_id):
@ -134,13 +136,19 @@ class GitLabProvider(GitProvider):
:param project_id:
:return:
"""
if request_data.get('object_attributes', {}).get("source_branch") and request_data.get('object_attributes', {}).get("target_branch"):
if request_data.get('object_attributes', {}).get(
"source_branch"
) and request_data.get('object_attributes', {}).get("target_branch"):
models.ProjectHistory.objects.create(
project_id=project_id,
project_url=request_data.get("project", {}).get("web_url"),
mr_url=request_data.get('object_attributes', {}).get("url"),
source_branch=request_data.get('object_attributes', {}).get("source_branch"),
target_branch=request_data.get('object_attributes', {}).get("target_branch"),
source_branch=request_data.get('object_attributes', {}).get(
"source_branch"
),
target_branch=request_data.get('object_attributes', {}).get(
"target_branch"
),
mr_title=request_data.get('object_attributes', {}).get("title"),
source_data=request_data,
)


@ -80,14 +80,20 @@ class PRAgent:
if action == "answer":
if notify:
notify()
await PRReviewer(pr_url, is_answer=True, args=args, ai_handler=self.ai_handler).run()
await PRReviewer(
pr_url, is_answer=True, args=args, ai_handler=self.ai_handler
).run()
elif action == "auto_review":
await PRReviewer(pr_url, is_auto=True, args=args, ai_handler=self.ai_handler).run()
await PRReviewer(
pr_url, is_auto=True, args=args, ai_handler=self.ai_handler
).run()
elif action in command2class:
if notify:
notify()
await command2class[action](pr_url, ai_handler=self.ai_handler, args=args).run()
await command2class[action](
pr_url, ai_handler=self.ai_handler, args=args
).run()
else:
return False
return True
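
The dispatch above is driven by an entry-point method whose signature is outside this hunk; a minimal usage sketch under the assumption that it keeps pr-agent's conventional handle_request(pr_url, request) form, with an illustrative merge-request URL:

import asyncio

from utils.pr_agent.agent.pr_agent import PRAgent   # module path assumed from imports shown elsewhere in this commit


async def review_one_mr():
    agent = PRAgent()
    # "/review" is one of the commands mapped through command2class above.
    await agent.handle_request("https://gitlab.example.com/group/repo/-/merge_requests/1", "/review")


asyncio.run(review_one_mr())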


@ -88,7 +88,7 @@ USER_MESSAGE_ONLY_MODELS = [
"deepseek/deepseek-reasoner",
"o1-mini",
"o1-mini-2024-09-12",
"o1-preview"
"o1-preview",
]
NO_SUPPORT_TEMPERATURE_MODELS = [
@ -99,5 +99,5 @@ NO_SUPPORT_TEMPERATURE_MODELS = [
"o1-2024-12-17",
"o3-mini",
"o3-mini-2025-01-31",
"o1-preview"
"o1-preview",
]


@ -16,7 +16,14 @@ class BaseAiHandler(ABC):
pass
@abstractmethod
async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2, img_path: str = None):
async def chat_completion(
self,
model: str,
system: str,
user: str,
temperature: float = 0.2,
img_path: str = None,
):
"""
This method should be implemented to return a chat completion from the AI model.
Args:


@ -34,9 +34,16 @@ class LangChainOpenAIHandler(BaseAiHandler):
"""
return get_settings().get("OPENAI.DEPLOYMENT_ID", None)
@retry(exceptions=(APIError, Timeout, AttributeError, RateLimitError),
tries=OPENAI_RETRIES, delay=2, backoff=2, jitter=(1, 3))
async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2):
@retry(
exceptions=(APIError, Timeout, AttributeError, RateLimitError),
tries=OPENAI_RETRIES,
delay=2,
backoff=2,
jitter=(1, 3),
)
async def chat_completion(
self, model: str, system: str, user: str, temperature: float = 0.2
):
try:
messages = [SystemMessage(content=system), HumanMessage(content=user)]
@ -45,7 +52,7 @@ class LangChainOpenAIHandler(BaseAiHandler):
finish_reason = "completed"
return resp.content, finish_reason
except (Exception) as e:
except Exception as e:
get_logger().error("Unknown error during OpenAI inference: ", e)
raise e
@ -66,7 +73,10 @@ class LangChainOpenAIHandler(BaseAiHandler):
if openai_api_base is None or len(openai_api_base) == 0:
return ChatOpenAI(openai_api_key=get_settings().openai.key)
else:
return ChatOpenAI(openai_api_key=get_settings().openai.key, openai_api_base=openai_api_base)
return ChatOpenAI(
openai_api_key=get_settings().openai.key,
openai_api_base=openai_api_base,
)
except AttributeError as e:
if getattr(e, "name"):
raise ValueError(f"OpenAI {e.name} is required") from e


@ -36,9 +36,14 @@ class LiteLLMAIHandler(BaseAiHandler):
elif 'OPENAI_API_KEY' not in os.environ:
litellm.api_key = "dummy_key"
if get_settings().get("aws.AWS_ACCESS_KEY_ID"):
assert get_settings().aws.AWS_SECRET_ACCESS_KEY and get_settings().aws.AWS_REGION_NAME, "AWS credentials are incomplete"
assert (
get_settings().aws.AWS_SECRET_ACCESS_KEY
and get_settings().aws.AWS_REGION_NAME
), "AWS credentials are incomplete"
os.environ["AWS_ACCESS_KEY_ID"] = get_settings().aws.AWS_ACCESS_KEY_ID
os.environ["AWS_SECRET_ACCESS_KEY"] = get_settings().aws.AWS_SECRET_ACCESS_KEY
os.environ[
"AWS_SECRET_ACCESS_KEY"
] = get_settings().aws.AWS_SECRET_ACCESS_KEY
os.environ["AWS_REGION_NAME"] = get_settings().aws.AWS_REGION_NAME
if get_settings().get("litellm.use_client"):
litellm_token = get_settings().get("litellm.LITELLM_TOKEN")
@ -73,14 +78,19 @@ class LiteLLMAIHandler(BaseAiHandler):
litellm.replicate_key = get_settings().replicate.key
if get_settings().get("HUGGINGFACE.KEY", None):
litellm.huggingface_key = get_settings().huggingface.key
if get_settings().get("HUGGINGFACE.API_BASE", None) and 'huggingface' in get_settings().config.model:
if (
get_settings().get("HUGGINGFACE.API_BASE", None)
and 'huggingface' in get_settings().config.model
):
litellm.api_base = get_settings().huggingface.api_base
self.api_base = get_settings().huggingface.api_base
if get_settings().get("OLLAMA.API_BASE", None):
litellm.api_base = get_settings().ollama.api_base
self.api_base = get_settings().ollama.api_base
if get_settings().get("HUGGINGFACE.REPETITION_PENALTY", None):
self.repetition_penalty = float(get_settings().huggingface.repetition_penalty)
self.repetition_penalty = float(
get_settings().huggingface.repetition_penalty
)
if get_settings().get("VERTEXAI.VERTEX_PROJECT", None):
litellm.vertex_project = get_settings().vertexai.vertex_project
litellm.vertex_location = get_settings().get(
@ -89,7 +99,9 @@ class LiteLLMAIHandler(BaseAiHandler):
# Google AI Studio
# SEE https://docs.litellm.ai/docs/providers/gemini
if get_settings().get("GOOGLE_AI_STUDIO.GEMINI_API_KEY", None):
os.environ["GEMINI_API_KEY"] = get_settings().google_ai_studio.gemini_api_key
os.environ[
"GEMINI_API_KEY"
] = get_settings().google_ai_studio.gemini_api_key
# Support deepseek models
if get_settings().get("DEEPSEEK.KEY", None):
@ -140,27 +152,35 @@ class LiteLLMAIHandler(BaseAiHandler):
git_provider = get_settings().config.git_provider
metadata = dict()
callbacks = litellm.success_callback + litellm.failure_callback + litellm.service_callback
callbacks = (
litellm.success_callback
+ litellm.failure_callback
+ litellm.service_callback
)
if "langfuse" in callbacks:
metadata.update({
"trace_name": command,
"tags": [git_provider, command, f'version:{get_version()}'],
"trace_metadata": {
"command": command,
"pr_url": pr_url,
},
})
if "langsmith" in callbacks:
metadata.update({
"run_name": command,
"tags": [git_provider, command, f'version:{get_version()}'],
"extra": {
"metadata": {
metadata.update(
{
"trace_name": command,
"tags": [git_provider, command, f'version:{get_version()}'],
"trace_metadata": {
"command": command,
"pr_url": pr_url,
}
},
})
},
}
)
if "langsmith" in callbacks:
metadata.update(
{
"run_name": command,
"tags": [git_provider, command, f'version:{get_version()}'],
"extra": {
"metadata": {
"command": command,
"pr_url": pr_url,
}
},
}
)
# Adding the captured logs to the kwargs
kwargs["metadata"] = metadata
@ -175,10 +195,19 @@ class LiteLLMAIHandler(BaseAiHandler):
return get_settings().get("OPENAI.DEPLOYMENT_ID", None)
@retry(
retry=retry_if_exception_type((openai.APIError, openai.APIConnectionError, openai.APITimeoutError)), # No retry on RateLimitError
stop=stop_after_attempt(OPENAI_RETRIES)
retry=retry_if_exception_type(
(openai.APIError, openai.APIConnectionError, openai.APITimeoutError)
), # No retry on RateLimitError
stop=stop_after_attempt(OPENAI_RETRIES),
)
async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2, img_path: str = None):
async def chat_completion(
self,
model: str,
system: str,
user: str,
temperature: float = 0.2,
img_path: str = None,
):
try:
resp, finish_reason = None, None
deployment_id = self.deployment_id
@ -187,8 +216,12 @@ class LiteLLMAIHandler(BaseAiHandler):
if 'claude' in model and not system:
system = "No system prompt provided"
get_logger().warning(
"Empty system prompt for claude model. Adding a newline character to prevent OpenAI API error.")
messages = [{"role": "system", "content": system}, {"role": "user", "content": user}]
"Empty system prompt for claude model. Adding a newline character to prevent OpenAI API error."
)
messages = [
{"role": "system", "content": system},
{"role": "user", "content": user},
]
if img_path:
try:
@ -201,14 +234,21 @@ class LiteLLMAIHandler(BaseAiHandler):
except Exception as e:
get_logger().error(f"Error fetching image: {img_path}", e)
return f"Error fetching image: {img_path}", "error"
messages[1]["content"] = [{"type": "text", "text": messages[1]["content"]},
{"type": "image_url", "image_url": {"url": img_path}}]
messages[1]["content"] = [
{"type": "text", "text": messages[1]["content"]},
{"type": "image_url", "image_url": {"url": img_path}},
]
# Currently, some models do not support a separate system and user prompts
if model in self.user_message_only_models or get_settings().config.custom_reasoning_model:
if (
model in self.user_message_only_models
or get_settings().config.custom_reasoning_model
):
user = f"{system}\n\n\n{user}"
system = ""
get_logger().info(f"Using model {model}, combining system and user prompts")
get_logger().info(
f"Using model {model}, combining system and user prompts"
)
messages = [{"role": "user", "content": user}]
kwargs = {
"model": model,
@ -227,7 +267,10 @@ class LiteLLMAIHandler(BaseAiHandler):
}
# Add temperature only if model supports it
if model not in self.no_support_temperature_models and not get_settings().config.custom_reasoning_model:
if (
model not in self.no_support_temperature_models
and not get_settings().config.custom_reasoning_model
):
kwargs["temperature"] = temperature
if get_settings().litellm.get("enable_callbacks", False):
@ -235,7 +278,9 @@ class LiteLLMAIHandler(BaseAiHandler):
seed = get_settings().config.get("seed", -1)
if temperature > 0 and seed >= 0:
raise ValueError(f"Seed ({seed}) is not supported with temperature ({temperature}) > 0")
raise ValueError(
f"Seed ({seed}) is not supported with temperature ({temperature}) > 0"
)
elif seed >= 0:
get_logger().info(f"Using fixed seed of {seed}")
kwargs["seed"] = seed
@ -253,10 +298,10 @@ class LiteLLMAIHandler(BaseAiHandler):
except (openai.APIError, openai.APITimeoutError) as e:
get_logger().warning(f"Error during LLM inference: {e}")
raise
except (openai.RateLimitError) as e:
except openai.RateLimitError as e:
get_logger().error(f"Rate limit error during LLM inference: {e}")
raise
except (Exception) as e:
except Exception as e:
get_logger().warning(f"Unknown error during LLM inference: {e}")
raise openai.APIError from e
if response is None or len(response["choices"]) == 0:
@ -267,7 +312,9 @@ class LiteLLMAIHandler(BaseAiHandler):
get_logger().debug(f"\nAI response:\n{resp}")
# log the full response for debugging
response_log = self.prepare_logs(response, system, user, resp, finish_reason)
response_log = self.prepare_logs(
response, system, user, resp, finish_reason
)
get_logger().debug("Full_response", artifact=response_log)
# for CLI debugging


@ -37,13 +37,23 @@ class OpenAIHandler(BaseAiHandler):
"""
return get_settings().get("OPENAI.DEPLOYMENT_ID", None)
@retry(exceptions=(APIError, Timeout, AttributeError, RateLimitError),
tries=OPENAI_RETRIES, delay=2, backoff=2, jitter=(1, 3))
async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2):
@retry(
exceptions=(APIError, Timeout, AttributeError, RateLimitError),
tries=OPENAI_RETRIES,
delay=2,
backoff=2,
jitter=(1, 3),
)
async def chat_completion(
self, model: str, system: str, user: str, temperature: float = 0.2
):
try:
get_logger().info("System: ", system)
get_logger().info("User: ", user)
messages = [{"role": "system", "content": system}, {"role": "user", "content": user}]
messages = [
{"role": "system", "content": system},
{"role": "user", "content": user},
]
client = AsyncOpenAI()
chat_completion = await client.chat.completions.create(
model=model,
@ -53,15 +63,21 @@ class OpenAIHandler(BaseAiHandler):
resp = chat_completion.choices[0].message.content
finish_reason = chat_completion.choices[0].finish_reason
usage = chat_completion.usage
get_logger().info("AI response", response=resp, messages=messages, finish_reason=finish_reason,
model=model, usage=usage)
get_logger().info(
"AI response",
response=resp,
messages=messages,
finish_reason=finish_reason,
model=model,
usage=usage,
)
return resp, finish_reason
except (APIError, Timeout) as e:
get_logger().error("Error during OpenAI inference: ", e)
raise
except (RateLimitError) as e:
except RateLimitError as e:
get_logger().error("Rate limit error during OpenAI inference: ", e)
raise
except (Exception) as e:
except Exception as e:
get_logger().error("Unknown error during OpenAI inference: ", e)
raise


@ -1,6 +1,7 @@
from base64 import b64decode
import hashlib
class CliArgs:
@staticmethod
def validate_user_args(args: list) -> (bool, str):
@ -23,12 +24,12 @@ class CliArgs:
for arg in args:
if arg.startswith('--'):
arg_word = arg.lower()
arg_word = arg_word.replace('__', '.') # replace double underscore with dot, e.g. --openai__key -> --openai.key
arg_word = arg_word.replace(
'__', '.'
) # replace double underscore with dot, e.g. --openai__key -> --openai.key
for forbidden_arg_word in forbidden_cli_args:
if forbidden_arg_word in arg_word:
return False, forbidden_arg_word
return True, ""
except Exception as e:
return False, str(e)


@ -4,7 +4,7 @@ import re
from utils.pr_agent.config_loader import get_settings
def filter_ignored(files, platform = 'github'):
def filter_ignored(files, platform='github'):
"""
Filter out files that match the ignore patterns.
"""
@ -15,7 +15,9 @@ def filter_ignored(files, platform = 'github'):
if isinstance(patterns, str):
patterns = [patterns]
glob_setting = get_settings().ignore.glob
if isinstance(glob_setting, str): # --ignore.glob=[.*utils.py], --ignore.glob=.*utils.py
if isinstance(
glob_setting, str
): # --ignore.glob=[.*utils.py], --ignore.glob=.*utils.py
glob_setting = glob_setting.strip('[]').split(",")
patterns += [fnmatch.translate(glob) for glob in glob_setting]
@ -31,7 +33,9 @@ def filter_ignored(files, platform = 'github'):
if files and isinstance(files, list):
for r in compiled_patterns:
if platform == 'github':
files = [f for f in files if (f.filename and not r.match(f.filename))]
files = [
f for f in files if (f.filename and not r.match(f.filename))
]
elif platform == 'bitbucket':
# files = [f for f in files if (f.new.path and not r.match(f.new.path))]
files_o = []
@ -49,10 +53,18 @@ def filter_ignored(files, platform = 'github'):
# files = [f for f in files if (f['new_path'] and not r.match(f['new_path']))]
files_o = []
for f in files:
if 'new_path' in f and f['new_path'] and not r.match(f['new_path']):
if (
'new_path' in f
and f['new_path']
and not r.match(f['new_path'])
):
files_o.append(f)
continue
if 'old_path' in f and f['old_path'] and not r.match(f['old_path']):
if (
'old_path' in f
and f['old_path']
and not r.match(f['old_path'])
):
files_o.append(f)
continue
files = files_o
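
For context on the glob branch above: fnmatch.translate turns each shell-style glob into a regular expression before it is compiled alongside the regex ignore patterns. A quick sketch with an illustrative setting value:

import fnmatch
import re

glob_setting = "[*.lock,docs/*]"   # e.g. passed as --ignore.glob=[*.lock,docs/*]
globs = glob_setting.strip("[]").split(",")
compiled = [re.compile(fnmatch.translate(g)) for g in globs]
print(any(r.match("Pipfile.lock") for r in compiled))      # True
print(any(r.match("webhook/views.py") for r in compiled))  # False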


@ -8,9 +8,18 @@ from utils.pr_agent.config_loader import get_settings
from utils.pr_agent.log import get_logger
def extend_patch(original_file_str, patch_str, patch_extra_lines_before=0,
patch_extra_lines_after=0, filename: str = "") -> str:
if not patch_str or (patch_extra_lines_before == 0 and patch_extra_lines_after == 0) or not original_file_str:
def extend_patch(
original_file_str,
patch_str,
patch_extra_lines_before=0,
patch_extra_lines_after=0,
filename: str = "",
) -> str:
if (
not patch_str
or (patch_extra_lines_before == 0 and patch_extra_lines_after == 0)
or not original_file_str
):
return patch_str
original_file_str = decode_if_bytes(original_file_str)
@ -21,10 +30,17 @@ def extend_patch(original_file_str, patch_str, patch_extra_lines_before=0,
return patch_str
try:
extended_patch_str = process_patch_lines(patch_str, original_file_str,
patch_extra_lines_before, patch_extra_lines_after)
extended_patch_str = process_patch_lines(
patch_str,
original_file_str,
patch_extra_lines_before,
patch_extra_lines_after,
)
except Exception as e:
get_logger().warning(f"Failed to extend patch: {e}", artifact={"traceback": traceback.format_exc()})
get_logger().warning(
f"Failed to extend patch: {e}",
artifact={"traceback": traceback.format_exc()},
)
return patch_str
return extended_patch_str
@ -48,13 +64,19 @@ def decode_if_bytes(original_file_str):
def should_skip_patch(filename):
patch_extension_skip_types = get_settings().config.patch_extension_skip_types
if patch_extension_skip_types and filename:
return any(filename.endswith(skip_type) for skip_type in patch_extension_skip_types)
return any(
filename.endswith(skip_type) for skip_type in patch_extension_skip_types
)
return False
def process_patch_lines(patch_str, original_file_str, patch_extra_lines_before, patch_extra_lines_after):
def process_patch_lines(
patch_str, original_file_str, patch_extra_lines_before, patch_extra_lines_after
):
allow_dynamic_context = get_settings().config.allow_dynamic_context
patch_extra_lines_before_dynamic = get_settings().config.max_extra_lines_before_dynamic_context
patch_extra_lines_before_dynamic = (
get_settings().config.max_extra_lines_before_dynamic_context
)
original_lines = original_file_str.splitlines()
len_original_lines = len(original_lines)
@ -63,59 +85,122 @@ def process_patch_lines(patch_str, original_file_str, patch_extra_lines_before,
is_valid_hunk = True
start1, size1, start2, size2 = -1, -1, -1, -1
RE_HUNK_HEADER = re.compile(
r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)")
RE_HUNK_HEADER = re.compile(r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)")
try:
for i,line in enumerate(patch_lines):
for i, line in enumerate(patch_lines):
if line.startswith('@@'):
match = RE_HUNK_HEADER.match(line)
# identify hunk header
if match:
# finish processing previous hunk
if is_valid_hunk and (start1 != -1 and patch_extra_lines_after > 0):
delta_lines = [f' {line}' for line in original_lines[start1 + size1 - 1:start1 + size1 - 1 + patch_extra_lines_after]]
delta_lines = [
f' {line}'
for line in original_lines[
start1
+ size1
- 1 : start1
+ size1
- 1
+ patch_extra_lines_after
]
]
extended_patch_lines.extend(delta_lines)
section_header, size1, size2, start1, start2 = extract_hunk_headers(match)
section_header, size1, size2, start1, start2 = extract_hunk_headers(
match
)
is_valid_hunk = check_if_hunk_lines_matches_to_file(i, original_lines, patch_lines, start1)
is_valid_hunk = check_if_hunk_lines_matches_to_file(
i, original_lines, patch_lines, start1
)
if is_valid_hunk and (
patch_extra_lines_before > 0 or patch_extra_lines_after > 0
):
if is_valid_hunk and (patch_extra_lines_before > 0 or patch_extra_lines_after > 0):
def _calc_context_limits(patch_lines_before):
extended_start1 = max(1, start1 - patch_lines_before)
extended_size1 = size1 + (start1 - extended_start1) + patch_extra_lines_after
extended_size1 = (
size1
+ (start1 - extended_start1)
+ patch_extra_lines_after
)
extended_start2 = max(1, start2 - patch_lines_before)
extended_size2 = size2 + (start2 - extended_start2) + patch_extra_lines_after
if extended_start1 - 1 + extended_size1 > len_original_lines:
extended_size2 = (
size2
+ (start2 - extended_start2)
+ patch_extra_lines_after
)
if (
extended_start1 - 1 + extended_size1
> len_original_lines
):
# we cannot extend beyond the original file
delta_cap = extended_start1 - 1 + extended_size1 - len_original_lines
delta_cap = (
extended_start1
- 1
+ extended_size1
- len_original_lines
)
extended_size1 = max(extended_size1 - delta_cap, size1)
extended_size2 = max(extended_size2 - delta_cap, size2)
return extended_start1, extended_size1, extended_start2, extended_size2
return (
extended_start1,
extended_size1,
extended_start2,
extended_size2,
)
if allow_dynamic_context:
extended_start1, extended_size1, extended_start2, extended_size2 = \
_calc_context_limits(patch_extra_lines_before_dynamic)
lines_before = original_lines[extended_start1 - 1:start1 - 1]
(
extended_start1,
extended_size1,
extended_start2,
extended_size2,
) = _calc_context_limits(patch_extra_lines_before_dynamic)
lines_before = original_lines[
extended_start1 - 1 : start1 - 1
]
found_header = False
for i, line, in enumerate(lines_before):
for (
i,
line,
) in enumerate(lines_before):
if section_header in line:
found_header = True
# Update start and size in one line each
extended_start1, extended_start2 = extended_start1 + i, extended_start2 + i
extended_size1, extended_size2 = extended_size1 - i, extended_size2 - i
extended_start1, extended_start2 = (
extended_start1 + i,
extended_start2 + i,
)
extended_size1, extended_size2 = (
extended_size1 - i,
extended_size2 - i,
)
# get_logger().debug(f"Found section header in line {i} before the hunk")
section_header = ''
break
if not found_header:
# get_logger().debug(f"Section header not found in the extra lines before the hunk")
extended_start1, extended_size1, extended_start2, extended_size2 = \
_calc_context_limits(patch_extra_lines_before)
(
extended_start1,
extended_size1,
extended_start2,
extended_size2,
) = _calc_context_limits(patch_extra_lines_before)
else:
extended_start1, extended_size1, extended_start2, extended_size2 = \
_calc_context_limits(patch_extra_lines_before)
(
extended_start1,
extended_size1,
extended_start2,
extended_size2,
) = _calc_context_limits(patch_extra_lines_before)
delta_lines = [f' {line}' for line in original_lines[extended_start1 - 1:start1 - 1]]
delta_lines = [
f' {line}'
for line in original_lines[extended_start1 - 1 : start1 - 1]
]
# logic to remove section header if its in the extra delta lines (in dynamic context, this is also done)
if section_header and not allow_dynamic_context:
@ -132,17 +217,23 @@ def process_patch_lines(patch_str, original_file_str, patch_extra_lines_before,
extended_patch_lines.append('')
extended_patch_lines.append(
f'@@ -{extended_start1},{extended_size1} '
f'+{extended_start2},{extended_size2} @@ {section_header}')
f'+{extended_start2},{extended_size2} @@ {section_header}'
)
extended_patch_lines.extend(delta_lines) # one to zero based
continue
extended_patch_lines.append(line)
except Exception as e:
get_logger().warning(f"Failed to extend patch: {e}", artifact={"traceback": traceback.format_exc()})
get_logger().warning(
f"Failed to extend patch: {e}",
artifact={"traceback": traceback.format_exc()},
)
return patch_str
# finish processing last hunk
if start1 != -1 and patch_extra_lines_after > 0 and is_valid_hunk:
delta_lines = original_lines[start1 + size1 - 1:start1 + size1 - 1 + patch_extra_lines_after]
delta_lines = original_lines[
start1 + size1 - 1 : start1 + size1 - 1 + patch_extra_lines_after
]
# add space at the beginning of each extra line
delta_lines = [f' {line}' for line in delta_lines]
extended_patch_lines.extend(delta_lines)
@ -158,11 +249,14 @@ def check_if_hunk_lines_matches_to_file(i, original_lines, patch_lines, start1):
"""
is_valid_hunk = True
try:
if i + 1 < len(patch_lines) and patch_lines[i + 1][0] == ' ': # an existing line in the file
if (
i + 1 < len(patch_lines) and patch_lines[i + 1][0] == ' '
): # an existing line in the file
if patch_lines[i + 1].strip() != original_lines[start1 - 1].strip():
is_valid_hunk = False
get_logger().error(
f"Invalid hunk in PR, line {start1} in hunk header doesn't match the original file content")
f"Invalid hunk in PR, line {start1} in hunk header doesn't match the original file content"
)
except:
pass
return is_valid_hunk
@ -195,8 +289,7 @@ def omit_deletion_hunks(patch_lines) -> str:
added_patched = []
add_hunk = False
inside_hunk = False
RE_HUNK_HEADER = re.compile(
r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))?\ @@[ ]?(.*)")
RE_HUNK_HEADER = re.compile(r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))?\ @@[ ]?(.*)")
for line in patch_lines:
if line.startswith('@@'):
@ -221,8 +314,13 @@ def omit_deletion_hunks(patch_lines) -> str:
return '\n'.join(added_patched)
def handle_patch_deletions(patch: str, original_file_content_str: str,
new_file_content_str: str, file_name: str, edit_type: EDIT_TYPE = EDIT_TYPE.UNKNOWN) -> str:
def handle_patch_deletions(
patch: str,
original_file_content_str: str,
new_file_content_str: str,
file_name: str,
edit_type: EDIT_TYPE = EDIT_TYPE.UNKNOWN,
) -> str:
"""
Handle entire file or deletion patches.
@ -239,11 +337,13 @@ def handle_patch_deletions(patch: str, original_file_content_str: str,
str: The modified patch with deletion hunks omitted.
"""
if not new_file_content_str and (edit_type == EDIT_TYPE.DELETED or edit_type == EDIT_TYPE.UNKNOWN):
if not new_file_content_str and (
edit_type == EDIT_TYPE.DELETED or edit_type == EDIT_TYPE.UNKNOWN
):
# logic for handling deleted files - don't show patch, just show that the file was deleted
if get_settings().config.verbosity_level > 0:
get_logger().info(f"Processing file: {file_name}, minimizing deletion file")
patch = None # file was deleted
patch = None # file was deleted
else:
patch_lines = patch.splitlines()
patch_new = omit_deletion_hunks(patch_lines)
@ -256,35 +356,35 @@ def handle_patch_deletions(patch: str, original_file_content_str: str,
def convert_to_hunks_with_lines_numbers(patch: str, file) -> str:
"""
Convert a given patch string into a string with line numbers for each hunk, indicating the new and old content of
the file.
Convert a given patch string into a string with line numbers for each hunk, indicating the new and old content of
the file.
Args:
patch (str): The patch string to be converted.
file: An object containing the filename of the file being patched.
Args:
patch (str): The patch string to be converted.
file: An object containing the filename of the file being patched.
Returns:
str: A string with line numbers for each hunk, indicating the new and old content of the file.
Returns:
str: A string with line numbers for each hunk, indicating the new and old content of the file.
example output:
## src/file.ts
__new hunk__
881 line1
882 line2
883 line3
887 + line4
888 + line5
889 line6
890 line7
...
__old hunk__
line1
line2
- line3
- line4
line5
line6
...
example output:
## src/file.ts
__new hunk__
881 line1
882 line2
883 line3
887 + line4
888 + line5
889 line6
890 line7
...
__old hunk__
line1
line2
- line3
- line4
line5
line6
...
"""
# if the file was deleted, return a message indicating that the file was deleted
if hasattr(file, 'edit_type') and file.edit_type == EDIT_TYPE.DELETED:
@ -292,8 +392,7 @@ __old hunk__
patch_with_lines_str = f"\n\n## File: '{file.filename.strip()}'\n"
patch_lines = patch.splitlines()
RE_HUNK_HEADER = re.compile(
r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)")
RE_HUNK_HEADER = re.compile(r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)")
new_content_lines = []
old_content_lines = []
match = None
@ -307,20 +406,32 @@ __old hunk__
if line.startswith('@@'):
header_line = line
match = RE_HUNK_HEADER.match(line)
if match and (new_content_lines or old_content_lines): # found a new hunk, split the previous lines
if match and (
new_content_lines or old_content_lines
): # found a new hunk, split the previous lines
if prev_header_line:
patch_with_lines_str += f'\n{prev_header_line}\n'
is_plus_lines = is_minus_lines = False
if new_content_lines:
is_plus_lines = any([line.startswith('+') for line in new_content_lines])
is_plus_lines = any(
[line.startswith('+') for line in new_content_lines]
)
if old_content_lines:
is_minus_lines = any([line.startswith('-') for line in old_content_lines])
if is_plus_lines or is_minus_lines: # notice 'True' here - we always present __new hunk__ for section, otherwise LLM gets confused
patch_with_lines_str = patch_with_lines_str.rstrip() + '\n__new hunk__\n'
is_minus_lines = any(
[line.startswith('-') for line in old_content_lines]
)
if (
is_plus_lines or is_minus_lines
): # notice 'True' here - we always present __new hunk__ for section, otherwise LLM gets confused
patch_with_lines_str = (
patch_with_lines_str.rstrip() + '\n__new hunk__\n'
)
for i, line_new in enumerate(new_content_lines):
patch_with_lines_str += f"{start2 + i} {line_new}\n"
if is_minus_lines:
patch_with_lines_str = patch_with_lines_str.rstrip() + '\n__old hunk__\n'
patch_with_lines_str = (
patch_with_lines_str.rstrip() + '\n__old hunk__\n'
)
for line_old in old_content_lines:
patch_with_lines_str += f"{line_old}\n"
new_content_lines = []
@ -335,8 +446,12 @@ __old hunk__
elif line.startswith('-'):
old_content_lines.append(line)
else:
if not line and line_i: # if this line is empty and the next line is a hunk header, skip it
if line_i + 1 < len(patch_lines) and patch_lines[line_i + 1].startswith('@@'):
if (
not line and line_i
): # if this line is empty and the next line is a hunk header, skip it
if line_i + 1 < len(patch_lines) and patch_lines[line_i + 1].startswith(
'@@'
):
continue
elif line_i + 1 == len(patch_lines):
continue
@ -351,7 +466,9 @@ __old hunk__
is_plus_lines = any([line.startswith('+') for line in new_content_lines])
if old_content_lines:
is_minus_lines = any([line.startswith('-') for line in old_content_lines])
if is_plus_lines or is_minus_lines: # notice 'True' here - we always present __new hunk__ for section, otherwise LLM gets confused
if (
is_plus_lines or is_minus_lines
): # notice 'True' here - we always present __new hunk__ for section, otherwise LLM gets confused
patch_with_lines_str = patch_with_lines_str.rstrip() + '\n__new hunk__\n'
for i, line_new in enumerate(new_content_lines):
patch_with_lines_str += f"{start2 + i} {line_new}\n"
@ -363,13 +480,16 @@ __old hunk__
return patch_with_lines_str.rstrip()
def extract_hunk_lines_from_patch(patch: str, file_name, line_start, line_end, side) -> tuple[str, str]:
def extract_hunk_lines_from_patch(
patch: str, file_name, line_start, line_end, side
) -> tuple[str, str]:
try:
patch_with_lines_str = f"\n\n## File: '{file_name.strip()}'\n\n"
selected_lines = ""
patch_lines = patch.splitlines()
RE_HUNK_HEADER = re.compile(
r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)")
r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)"
)
match = None
start1, size1, start2, size2 = -1, -1, -1, -1
skip_hunk = False
@ -385,7 +505,9 @@ def extract_hunk_lines_from_patch(patch: str, file_name, line_start, line_end, s
match = RE_HUNK_HEADER.match(line)
section_header, size1, size2, start1, start2 = extract_hunk_headers(match)
section_header, size1, size2, start1, start2 = extract_hunk_headers(
match
)
# check if line range is in this hunk
if side.lower() == 'left':
@ -400,15 +522,26 @@ def extract_hunk_lines_from_patch(patch: str, file_name, line_start, line_end, s
patch_with_lines_str += f'\n{header_line}\n'
elif not skip_hunk:
if side.lower() == 'right' and line_start <= start2 + selected_lines_num <= line_end:
if (
side.lower() == 'right'
and line_start <= start2 + selected_lines_num <= line_end
):
selected_lines += line + '\n'
if side.lower() == 'left' and start1 <= selected_lines_num + start1 <= line_end:
if (
side.lower() == 'left'
and start1 <= selected_lines_num + start1 <= line_end
):
selected_lines += line + '\n'
patch_with_lines_str += line + '\n'
if not line.startswith('-'): # currently we don't support /ask line for deleted lines
if not line.startswith(
'-'
): # currently we don't support /ask line for deleted lines
selected_lines_num += 1
except Exception as e:
get_logger().error(f"Failed to extract hunk lines from patch: {e}", artifact={"traceback": traceback.format_exc()})
get_logger().error(
f"Failed to extract hunk lines from patch: {e}",
artifact={"traceback": traceback.format_exc()},
)
return "", ""
return patch_with_lines_str.rstrip(), selected_lines.rstrip()
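
The RE_HUNK_HEADER pattern used throughout this module parses unified-diff hunk headers of the form @@ -start,count +start,count @@. A small sketch of the groups it yields, using an illustrative header line:

import re

RE_HUNK_HEADER = re.compile(r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)")

match = RE_HUNK_HEADER.match("@@ -63,59 +85,122 @@ def process_patch_lines(...):")
start1, size1, start2, size2, section_header = match.groups()
print(start1, size1, start2, size2)   # 63 59 85 122
print(section_header)                 # def process_patch_lines(...):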


@ -9,10 +9,14 @@ def filter_bad_extensions(files):
bad_extensions = get_settings().bad_extensions.default
if get_settings().config.use_extra_bad_extensions:
bad_extensions += get_settings().bad_extensions.extra
return [f for f in files if f.filename is not None and is_valid_file(f.filename, bad_extensions)]
return [
f
for f in files
if f.filename is not None and is_valid_file(f.filename, bad_extensions)
]
def is_valid_file(filename:str, bad_extensions=None) -> bool:
def is_valid_file(filename: str, bad_extensions=None) -> bool:
if not filename:
return False
if not bad_extensions:
@ -27,12 +31,16 @@ def sort_files_by_main_languages(languages: Dict, files: list):
Sort files by their main language, put the files that are in the main language first and the rest files after
"""
# sort languages by their size
languages_sorted_list = [k for k, v in sorted(languages.items(), key=lambda item: item[1], reverse=True)]
languages_sorted_list = [
k for k, v in sorted(languages.items(), key=lambda item: item[1], reverse=True)
]
# languages_sorted = sorted(languages, key=lambda x: x[1], reverse=True)
# get all extensions for the languages
main_extensions = []
language_extension_map_org = get_settings().language_extension_map_org
language_extension_map = {k.lower(): v for k, v in language_extension_map_org.items()}
language_extension_map = {
k.lower(): v for k, v in language_extension_map_org.items()
}
for language in languages_sorted_list:
if language.lower() in language_extension_map:
main_extensions.append(language_extension_map[language.lower()])
@ -62,7 +70,9 @@ def sort_files_by_main_languages(languages: Dict, files: list):
if extension_str in extensions:
tmp.append(file)
else:
if (file.filename not in rest_files) and (extension_str not in main_extensions_flat):
if (file.filename not in rest_files) and (
extension_str not in main_extensions_flat
):
rest_files[file.filename] = file
if len(tmp) > 0:
files_sorted.append({"language": lang, "files": tmp})


@ -7,18 +7,28 @@ from github import RateLimitExceededException
from utils.pr_agent.algo.file_filter import filter_ignored
from utils.pr_agent.algo.git_patch_processing import (
convert_to_hunks_with_lines_numbers, extend_patch, handle_patch_deletions)
convert_to_hunks_with_lines_numbers,
extend_patch,
handle_patch_deletions,
)
from utils.pr_agent.algo.language_handler import sort_files_by_main_languages
from utils.pr_agent.algo.token_handler import TokenHandler
from utils.pr_agent.algo.types import EDIT_TYPE
from utils.pr_agent.algo.utils import ModelType, clip_tokens, get_max_tokens, get_weak_model
from utils.pr_agent.algo.utils import (
ModelType,
clip_tokens,
get_max_tokens,
get_weak_model,
)
from utils.pr_agent.config_loader import get_settings
from utils.pr_agent.git_providers.git_provider import GitProvider
from utils.pr_agent.log import get_logger
DELETED_FILES_ = "Deleted files:\n"
MORE_MODIFIED_FILES_ = "Additional modified files (insufficient token budget to process):\n"
MORE_MODIFIED_FILES_ = (
"Additional modified files (insufficient token budget to process):\n"
)
ADDED_FILES_ = "Additional added files (insufficient token budget to process):\n"
@ -29,45 +39,59 @@ MAX_EXTRA_LINES = 10
def cap_and_log_extra_lines(value, direction) -> int:
if value > MAX_EXTRA_LINES:
get_logger().warning(f"patch_extra_lines_{direction} was {value}, capping to {MAX_EXTRA_LINES}")
get_logger().warning(
f"patch_extra_lines_{direction} was {value}, capping to {MAX_EXTRA_LINES}"
)
return MAX_EXTRA_LINES
return value
def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler,
model: str,
add_line_numbers_to_hunks: bool = False,
disable_extra_lines: bool = False,
large_pr_handling=False,
return_remaining_files=False):
def get_pr_diff(
git_provider: GitProvider,
token_handler: TokenHandler,
model: str,
add_line_numbers_to_hunks: bool = False,
disable_extra_lines: bool = False,
large_pr_handling=False,
return_remaining_files=False,
):
if disable_extra_lines:
PATCH_EXTRA_LINES_BEFORE = 0
PATCH_EXTRA_LINES_AFTER = 0
else:
PATCH_EXTRA_LINES_BEFORE = get_settings().config.patch_extra_lines_before
PATCH_EXTRA_LINES_AFTER = get_settings().config.patch_extra_lines_after
PATCH_EXTRA_LINES_BEFORE = cap_and_log_extra_lines(PATCH_EXTRA_LINES_BEFORE, "before")
PATCH_EXTRA_LINES_AFTER = cap_and_log_extra_lines(PATCH_EXTRA_LINES_AFTER, "after")
PATCH_EXTRA_LINES_BEFORE = cap_and_log_extra_lines(
PATCH_EXTRA_LINES_BEFORE, "before"
)
PATCH_EXTRA_LINES_AFTER = cap_and_log_extra_lines(
PATCH_EXTRA_LINES_AFTER, "after"
)
try:
diff_files_original = git_provider.get_diff_files()
except RateLimitExceededException as e:
get_logger().error(f"Rate limit exceeded for git provider API. original message {e}")
get_logger().error(
f"Rate limit exceeded for git provider API. original message {e}"
)
raise
diff_files = filter_ignored(diff_files_original)
if diff_files != diff_files_original:
try:
get_logger().info(f"Filtered out {len(diff_files_original) - len(diff_files)} files")
get_logger().info(
f"Filtered out {len(diff_files_original) - len(diff_files)} files"
)
new_names = set([a.filename for a in diff_files])
orig_names = set([a.filename for a in diff_files_original])
get_logger().info(f"Filtered out files: {orig_names - new_names}")
except Exception as e:
pass
# get pr languages
pr_languages = sort_files_by_main_languages(git_provider.get_languages(), diff_files)
pr_languages = sort_files_by_main_languages(
git_provider.get_languages(), diff_files
)
if pr_languages:
try:
get_logger().info(f"PR main language: {pr_languages[0]['language']}")
@ -76,24 +100,42 @@ def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler,
# generate a standard diff string, with patch extension
patches_extended, total_tokens, patches_extended_tokens = pr_generate_extended_diff(
pr_languages, token_handler, add_line_numbers_to_hunks,
patch_extra_lines_before=PATCH_EXTRA_LINES_BEFORE, patch_extra_lines_after=PATCH_EXTRA_LINES_AFTER)
pr_languages,
token_handler,
add_line_numbers_to_hunks,
patch_extra_lines_before=PATCH_EXTRA_LINES_BEFORE,
patch_extra_lines_after=PATCH_EXTRA_LINES_AFTER,
)
# if we are under the limit, return the full diff
if total_tokens + OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD < get_max_tokens(model):
get_logger().info(f"Tokens: {total_tokens}, total tokens under limit: {get_max_tokens(model)}, "
f"returning full diff.")
get_logger().info(
f"Tokens: {total_tokens}, total tokens under limit: {get_max_tokens(model)}, "
f"returning full diff."
)
return "\n".join(patches_extended)
# if we are over the limit, start pruning (If we got here, we will not extend the patches with extra lines)
get_logger().info(f"Tokens: {total_tokens}, total tokens over limit: {get_max_tokens(model)}, "
f"pruning diff.")
patches_compressed_list, total_tokens_list, deleted_files_list, remaining_files_list, file_dict, files_in_patches_list = \
pr_generate_compressed_diff(pr_languages, token_handler, model, add_line_numbers_to_hunks, large_pr_handling)
get_logger().info(
f"Tokens: {total_tokens}, total tokens over limit: {get_max_tokens(model)}, "
f"pruning diff."
)
(
patches_compressed_list,
total_tokens_list,
deleted_files_list,
remaining_files_list,
file_dict,
files_in_patches_list,
) = pr_generate_compressed_diff(
pr_languages, token_handler, model, add_line_numbers_to_hunks, large_pr_handling
)
if large_pr_handling and len(patches_compressed_list) > 1:
get_logger().info(f"Large PR handling mode, and found {len(patches_compressed_list)} patches with original diff.")
return "" # return empty string, as we want to generate multiple patches with a different prompt
get_logger().info(
f"Large PR handling mode, and found {len(patches_compressed_list)} patches with original diff."
)
return "" # return empty string, as we want to generate multiple patches with a different prompt
# return the first patch
patches_compressed = patches_compressed_list[0]
@ -144,26 +186,37 @@ def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler,
if deleted_list_str:
final_diff = final_diff + "\n\n" + deleted_list_str
get_logger().debug(f"After pruning, added_list_str: {added_list_str}, modified_list_str: {modified_list_str}, "
f"deleted_list_str: {deleted_list_str}")
get_logger().debug(
f"After pruning, added_list_str: {added_list_str}, modified_list_str: {modified_list_str}, "
f"deleted_list_str: {deleted_list_str}"
)
if not return_remaining_files:
return final_diff
else:
return final_diff, remaining_files_list
def get_pr_diff_multiple_patchs(git_provider: GitProvider, token_handler: TokenHandler, model: str,
add_line_numbers_to_hunks: bool = False, disable_extra_lines: bool = False):
def get_pr_diff_multiple_patchs(
git_provider: GitProvider,
token_handler: TokenHandler,
model: str,
add_line_numbers_to_hunks: bool = False,
disable_extra_lines: bool = False,
):
try:
diff_files_original = git_provider.get_diff_files()
except RateLimitExceededException as e:
get_logger().error(f"Rate limit exceeded for git provider API. original message {e}")
get_logger().error(
f"Rate limit exceeded for git provider API. original message {e}"
)
raise
diff_files = filter_ignored(diff_files_original)
if diff_files != diff_files_original:
try:
get_logger().info(f"Filtered out {len(diff_files_original) - len(diff_files)} files")
get_logger().info(
f"Filtered out {len(diff_files_original) - len(diff_files)} files"
)
new_names = set([a.filename for a in diff_files])
orig_names = set([a.filename for a in diff_files_original])
get_logger().info(f"Filtered out files: {orig_names - new_names}")
@ -171,24 +224,47 @@ def get_pr_diff_multiple_patchs(git_provider: GitProvider, token_handler: TokenH
pass
# get pr languages
pr_languages = sort_files_by_main_languages(git_provider.get_languages(), diff_files)
pr_languages = sort_files_by_main_languages(
git_provider.get_languages(), diff_files
)
if pr_languages:
try:
get_logger().info(f"PR main language: {pr_languages[0]['language']}")
except Exception as e:
pass
patches_compressed_list, total_tokens_list, deleted_files_list, remaining_files_list, file_dict, files_in_patches_list = \
pr_generate_compressed_diff(pr_languages, token_handler, model, add_line_numbers_to_hunks, large_pr_handling=True)
(
patches_compressed_list,
total_tokens_list,
deleted_files_list,
remaining_files_list,
file_dict,
files_in_patches_list,
) = pr_generate_compressed_diff(
pr_languages,
token_handler,
model,
add_line_numbers_to_hunks,
large_pr_handling=True,
)
return patches_compressed_list, total_tokens_list, deleted_files_list, remaining_files_list, file_dict, files_in_patches_list
return (
patches_compressed_list,
total_tokens_list,
deleted_files_list,
remaining_files_list,
file_dict,
files_in_patches_list,
)
def pr_generate_extended_diff(pr_languages: list,
token_handler: TokenHandler,
add_line_numbers_to_hunks: bool,
patch_extra_lines_before: int = 0,
patch_extra_lines_after: int = 0) -> Tuple[list, int, list]:
def pr_generate_extended_diff(
pr_languages: list,
token_handler: TokenHandler,
add_line_numbers_to_hunks: bool,
patch_extra_lines_before: int = 0,
patch_extra_lines_after: int = 0,
) -> Tuple[list, int, list]:
total_tokens = token_handler.prompt_tokens # initial tokens
patches_extended = []
patches_extended_tokens = []
@ -200,20 +276,33 @@ def pr_generate_extended_diff(pr_languages: list,
continue
# extend each patch with extra lines of context
extended_patch = extend_patch(original_file_content_str, patch,
patch_extra_lines_before, patch_extra_lines_after, file.filename)
extended_patch = extend_patch(
original_file_content_str,
patch,
patch_extra_lines_before,
patch_extra_lines_after,
file.filename,
)
if not extended_patch:
get_logger().warning(f"Failed to extend patch for file: {file.filename}")
get_logger().warning(
f"Failed to extend patch for file: {file.filename}"
)
continue
if add_line_numbers_to_hunks:
full_extended_patch = convert_to_hunks_with_lines_numbers(extended_patch, file)
full_extended_patch = convert_to_hunks_with_lines_numbers(
extended_patch, file
)
else:
full_extended_patch = f"\n\n## File: '{file.filename.strip()}'\n{extended_patch.rstrip()}\n"
# add AI-summary metadata to the patch
if file.ai_file_summary and get_settings().get("config.enable_ai_metadata", False):
full_extended_patch = add_ai_summary_top_patch(file, full_extended_patch)
if file.ai_file_summary and get_settings().get(
"config.enable_ai_metadata", False
):
full_extended_patch = add_ai_summary_top_patch(
file, full_extended_patch
)
patch_tokens = token_handler.count_tokens(full_extended_patch)
file.tokens = patch_tokens
@ -224,9 +313,13 @@ def pr_generate_extended_diff(pr_languages: list,
return patches_extended, total_tokens, patches_extended_tokens
def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, model: str,
convert_hunks_to_line_numbers: bool,
large_pr_handling: bool) -> Tuple[list, list, list, list, dict, list]:
def pr_generate_compressed_diff(
top_langs: list,
token_handler: TokenHandler,
model: str,
convert_hunks_to_line_numbers: bool,
large_pr_handling: bool,
) -> Tuple[list, list, list, list, dict, list]:
deleted_files_list = []
# sort each one of the languages in top_langs by the number of tokens in the diff
@ -244,8 +337,13 @@ def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, mo
continue
# removing delete-only hunks
patch = handle_patch_deletions(patch, original_file_content_str,
new_file_content_str, file.filename, file.edit_type)
patch = handle_patch_deletions(
patch,
original_file_content_str,
new_file_content_str,
file.filename,
file.edit_type,
)
if patch is None:
if file.filename not in deleted_files_list:
deleted_files_list.append(file.filename)
@ -259,30 +357,54 @@ def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, mo
# patch = add_ai_summary_top_patch(file, patch)
new_patch_tokens = token_handler.count_tokens(patch)
file_dict[file.filename] = {'patch': patch, 'tokens': new_patch_tokens, 'edit_type': file.edit_type}
file_dict[file.filename] = {
'patch': patch,
'tokens': new_patch_tokens,
'edit_type': file.edit_type,
}
max_tokens_model = get_max_tokens(model)
# first iteration
files_in_patches_list = []
remaining_files_list = [file.filename for file in sorted_files]
patches_list =[]
remaining_files_list = [file.filename for file in sorted_files]
patches_list = []
total_tokens_list = []
total_tokens, patches, remaining_files_list, files_in_patch_list = generate_full_patch(convert_hunks_to_line_numbers, file_dict,
max_tokens_model, remaining_files_list, token_handler)
(
total_tokens,
patches,
remaining_files_list,
files_in_patch_list,
) = generate_full_patch(
convert_hunks_to_line_numbers,
file_dict,
max_tokens_model,
remaining_files_list,
token_handler,
)
patches_list.append(patches)
total_tokens_list.append(total_tokens)
files_in_patches_list.append(files_in_patch_list)
# additional iterations (if needed)
if large_pr_handling:
NUMBER_OF_ALLOWED_ITERATIONS = get_settings().pr_description.max_ai_calls - 1 # one more call is to summarize
for i in range(NUMBER_OF_ALLOWED_ITERATIONS-1):
NUMBER_OF_ALLOWED_ITERATIONS = (
get_settings().pr_description.max_ai_calls - 1
) # one more call is to summarize
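# Worked example with a hypothetical setting value: if pr_description.max_ai_calls
# were 4, NUMBER_OF_ALLOWED_ITERATIONS would be 3, so the loop below performs at
# most 2 additional patch-generation calls on top of the first one above, i.e.
# 3 generation calls plus the final summarizing call.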
for i in range(NUMBER_OF_ALLOWED_ITERATIONS - 1):
if remaining_files_list:
total_tokens, patches, remaining_files_list, files_in_patch_list = generate_full_patch(convert_hunks_to_line_numbers,
file_dict,
max_tokens_model,
remaining_files_list, token_handler)
(
total_tokens,
patches,
remaining_files_list,
files_in_patch_list,
) = generate_full_patch(
convert_hunks_to_line_numbers,
file_dict,
max_tokens_model,
remaining_files_list,
token_handler,
)
if patches:
patches_list.append(patches)
total_tokens_list.append(total_tokens)
@ -290,11 +412,24 @@ def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, mo
else:
break
return patches_list, total_tokens_list, deleted_files_list, remaining_files_list, file_dict, files_in_patches_list
return (
patches_list,
total_tokens_list,
deleted_files_list,
remaining_files_list,
file_dict,
files_in_patches_list,
)
def generate_full_patch(convert_hunks_to_line_numbers, file_dict, max_tokens_model,remaining_files_list_prev, token_handler):
total_tokens = token_handler.prompt_tokens # initial tokens
def generate_full_patch(
convert_hunks_to_line_numbers,
file_dict,
max_tokens_model,
remaining_files_list_prev,
token_handler,
):
total_tokens = token_handler.prompt_tokens # initial tokens
patches = []
remaining_files_list_new = []
files_in_patch_list = []
@ -312,7 +447,10 @@ def generate_full_patch(convert_hunks_to_line_numbers, file_dict, max_tokens_mod
continue
# If the patch is too large, just show the file name
if total_tokens + new_patch_tokens > max_tokens_model - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD:
if (
total_tokens + new_patch_tokens
> max_tokens_model - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD
):
# Current logic is to skip the patch if it's too large
# TODO: Option for alternative logic to remove hunks from the patch to reduce the number of tokens
# until we meet the requirements
@ -334,7 +472,9 @@ def generate_full_patch(convert_hunks_to_line_numbers, file_dict, max_tokens_mod
return total_tokens, patches, remaining_files_list_new, files_in_patch_list
async def retry_with_fallback_models(f: Callable, model_type: ModelType = ModelType.REGULAR):
async def retry_with_fallback_models(
f: Callable, model_type: ModelType = ModelType.REGULAR
):
all_models = _get_all_models(model_type)
all_deployments = _get_all_deployments(all_models)
# try each (model, deployment_id) pair until one is successful, otherwise raise exception
@ -347,11 +487,11 @@ async def retry_with_fallback_models(f: Callable, model_type: ModelType = ModelT
get_settings().set("openai.deployment_id", deployment_id)
return await f(model)
except:
get_logger().warning(
f"Failed to generate prediction with {model}"
)
get_logger().warning(f"Failed to generate prediction with {model}")
if i == len(all_models) - 1: # If it's the last iteration
raise Exception(f"Failed to generate prediction with any model of {all_models}")
raise Exception(
f"Failed to generate prediction with any model of {all_models}"
)
def _get_all_models(model_type: ModelType = ModelType.REGULAR) -> List[str]:
@ -374,17 +514,21 @@ def _get_all_deployments(all_models: List[str]) -> List[str]:
if fallback_deployments:
all_deployments = [deployment_id] + fallback_deployments
if len(all_deployments) < len(all_models):
raise ValueError(f"The number of deployments ({len(all_deployments)}) "
f"is less than the number of models ({len(all_models)})")
raise ValueError(
f"The number of deployments ({len(all_deployments)}) "
f"is less than the number of models ({len(all_models)})"
)
else:
all_deployments = [deployment_id] * len(all_models)
return all_deployments
def get_pr_multi_diffs(git_provider: GitProvider,
token_handler: TokenHandler,
model: str,
max_calls: int = 5) -> List[str]:
def get_pr_multi_diffs(
git_provider: GitProvider,
token_handler: TokenHandler,
model: str,
max_calls: int = 5,
) -> List[str]:
"""
Retrieves the diff files from a Git provider, sorts them by main language, and generates patches for each file.
The patches are split into multiple groups based on the maximum number of tokens allowed for the given model.
@ -404,13 +548,17 @@ def get_pr_multi_diffs(git_provider: GitProvider,
try:
diff_files = git_provider.get_diff_files()
except RateLimitExceededException as e:
get_logger().error(f"Rate limit exceeded for git provider API. original message {e}")
get_logger().error(
f"Rate limit exceeded for git provider API. original message {e}"
)
raise
diff_files = filter_ignored(diff_files)
# Sort files by main language
pr_languages = sort_files_by_main_languages(git_provider.get_languages(), diff_files)
pr_languages = sort_files_by_main_languages(
git_provider.get_languages(), diff_files
)
# Sort files within each language group by tokens in descending order
sorted_files = []
@ -420,14 +568,19 @@ def get_pr_multi_diffs(git_provider: GitProvider,
# Get the maximum number of extra lines before and after the patch
PATCH_EXTRA_LINES_BEFORE = get_settings().config.patch_extra_lines_before
PATCH_EXTRA_LINES_AFTER = get_settings().config.patch_extra_lines_after
PATCH_EXTRA_LINES_BEFORE = cap_and_log_extra_lines(PATCH_EXTRA_LINES_BEFORE, "before")
PATCH_EXTRA_LINES_BEFORE = cap_and_log_extra_lines(
PATCH_EXTRA_LINES_BEFORE, "before"
)
PATCH_EXTRA_LINES_AFTER = cap_and_log_extra_lines(PATCH_EXTRA_LINES_AFTER, "after")
# try first a single run with standard diff string, with patch extension, and no deletions
patches_extended, total_tokens, patches_extended_tokens = pr_generate_extended_diff(
pr_languages, token_handler, add_line_numbers_to_hunks=True,
pr_languages,
token_handler,
add_line_numbers_to_hunks=True,
patch_extra_lines_before=PATCH_EXTRA_LINES_BEFORE,
patch_extra_lines_after=PATCH_EXTRA_LINES_AFTER)
patch_extra_lines_after=PATCH_EXTRA_LINES_AFTER,
)
# if we are under the limit, return the full diff
if total_tokens + OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD < get_max_tokens(model):
@ -450,27 +603,50 @@ def get_pr_multi_diffs(git_provider: GitProvider,
continue
# Remove delete-only hunks
patch = handle_patch_deletions(patch, original_file_content_str, new_file_content_str, file.filename, file.edit_type)
patch = handle_patch_deletions(
patch,
original_file_content_str,
new_file_content_str,
file.filename,
file.edit_type,
)
if patch is None:
continue
patch = convert_to_hunks_with_lines_numbers(patch, file)
# add AI-summary metadata to the patch
if file.ai_file_summary and get_settings().get("config.enable_ai_metadata", False):
if file.ai_file_summary and get_settings().get(
"config.enable_ai_metadata", False
):
patch = add_ai_summary_top_patch(file, patch)
new_patch_tokens = token_handler.count_tokens(patch)
if patch and (token_handler.prompt_tokens + new_patch_tokens) > get_max_tokens(
model) - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD:
if (
patch
and (token_handler.prompt_tokens + new_patch_tokens)
> get_max_tokens(model) - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD
):
if get_settings().config.get('large_patch_policy', 'skip') == 'skip':
get_logger().warning(f"Patch too large, skipping: {file.filename}")
continue
elif get_settings().config.get('large_patch_policy') == 'clip':
delta_tokens = get_max_tokens(model) - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD - token_handler.prompt_tokens
patch_clipped = clip_tokens(patch, delta_tokens, delete_last_line=True, num_input_tokens=new_patch_tokens)
delta_tokens = (
get_max_tokens(model)
- OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD
- token_handler.prompt_tokens
)
patch_clipped = clip_tokens(
patch,
delta_tokens,
delete_last_line=True,
num_input_tokens=new_patch_tokens,
)
new_patch_tokens = token_handler.count_tokens(patch_clipped)
if patch_clipped and (token_handler.prompt_tokens + new_patch_tokens) > get_max_tokens(
model) - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD:
if (
patch_clipped
and (token_handler.prompt_tokens + new_patch_tokens)
> get_max_tokens(model) - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD
):
get_logger().warning(f"Patch too large, skipping: {file.filename}")
continue
else:
@ -480,13 +656,16 @@ def get_pr_multi_diffs(git_provider: GitProvider,
get_logger().warning(f"Patch too large, skipping: {file.filename}")
continue
if patch and (total_tokens + new_patch_tokens > get_max_tokens(model) - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD):
if patch and (
total_tokens + new_patch_tokens
> get_max_tokens(model) - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD
):
final_diff = "\n".join(patches)
final_diff_list.append(final_diff)
patches = []
total_tokens = token_handler.prompt_tokens
call_number += 1
if call_number > max_calls: # avoid creating new patches
if call_number > max_calls: # avoid creating new patches
if get_settings().config.verbosity_level >= 2:
get_logger().info(f"Reached max calls ({max_calls})")
break
@ -497,7 +676,9 @@ def get_pr_multi_diffs(git_provider: GitProvider,
patches.append(patch)
total_tokens += new_patch_tokens
if get_settings().config.verbosity_level >= 2:
get_logger().info(f"Tokens: {total_tokens}, last filename: {file.filename}")
get_logger().info(
f"Tokens: {total_tokens}, last filename: {file.filename}"
)
# Add the last chunk
if patches:
@ -515,7 +696,10 @@ def add_ai_metadata_to_diff_files(git_provider, pr_description_files):
if not pr_description_files:
get_logger().warning(f"PR description files are empty.")
return
available_files = {pr_file['full_file_name'].strip(): pr_file for pr_file in pr_description_files}
available_files = {
pr_file['full_file_name'].strip(): pr_file
for pr_file in pr_description_files
}
diff_files = git_provider.get_diff_files()
found_any_match = False
for file in diff_files:
@ -524,11 +708,15 @@ def add_ai_metadata_to_diff_files(git_provider, pr_description_files):
file.ai_file_summary = available_files[filename]
found_any_match = True
if not found_any_match:
get_logger().error(f"Failed to find any matching files between PR description and diff files.",
artifact={"pr_description_files": pr_description_files})
get_logger().error(
f"Failed to find any matching files between PR description and diff files.",
artifact={"pr_description_files": pr_description_files},
)
except Exception as e:
get_logger().error(f"Failed to add AI metadata to diff files: {e}",
artifact={"traceback": traceback.format_exc()})
get_logger().error(
f"Failed to add AI metadata to diff files: {e}",
artifact={"traceback": traceback.format_exc()},
)
def add_ai_summary_top_patch(file, full_extended_patch):
@ -537,14 +725,18 @@ def add_ai_summary_top_patch(file, full_extended_patch):
full_extended_patch_lines = full_extended_patch.split("\n")
for i, line in enumerate(full_extended_patch_lines):
if line.startswith("## File:") or line.startswith("## file:"):
full_extended_patch_lines.insert(i + 1,
f"### AI-generated changes summary:\n{file.ai_file_summary['long_summary']}")
full_extended_patch_lines.insert(
i + 1,
f"### AI-generated changes summary:\n{file.ai_file_summary['long_summary']}",
)
full_extended_patch = "\n".join(full_extended_patch_lines)
return full_extended_patch
# if no '## File: ...' was found
return full_extended_patch
except Exception as e:
get_logger().error(f"Failed to add AI summary to the top of the patch: {e}",
artifact={"traceback": traceback.format_exc()})
get_logger().error(
f"Failed to add AI summary to the top of the patch: {e}",
artifact={"traceback": traceback.format_exc()},
)
return full_extended_patch
View File
@ -15,12 +15,17 @@ class TokenEncoder:
@classmethod
def get_token_encoder(cls):
model = get_settings().config.model
if cls._encoder_instance is None or model != cls._model: # Check without acquiring the lock for performance
if (
cls._encoder_instance is None or model != cls._model
): # Check without acquiring the lock for performance
with cls._lock: # Lock acquisition to ensure thread safety
if cls._encoder_instance is None or model != cls._model:
cls._model = model
cls._encoder_instance = encoding_for_model(cls._model) if "gpt" in cls._model else get_encoding(
"cl100k_base")
cls._encoder_instance = (
encoding_for_model(cls._model)
if "gpt" in cls._model
else get_encoding("cl100k_base")
)
return cls._encoder_instance
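# Note: this is the standard double-checked locking pattern; the unlocked check keeps
# the hot path cheap, and the repeated check inside the lock guards against another
# thread having already rebuilt the encoder after a model change.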
@ -49,7 +54,9 @@ class TokenHandler:
"""
self.encoder = TokenEncoder.get_token_encoder()
if pr is not None:
self.prompt_tokens = self._get_system_user_tokens(pr, self.encoder, vars, system, user)
self.prompt_tokens = self._get_system_user_tokens(
pr, self.encoder, vars, system, user
)
def _get_system_user_tokens(self, pr, encoder, vars: dict, system, user):
"""
View File
@ -41,10 +41,12 @@ class Range(BaseModel):
column_start: int = -1
column_end: int = -1
class ModelType(str, Enum):
REGULAR = "regular"
WEAK = "weak"
class PRReviewHeader(str, Enum):
REGULAR = "## PR 评审指南"
INCREMENTAL = "## 增量 PR 评审指南"
@ -57,7 +59,9 @@ class PRDescriptionHeader(str, Enum):
def get_setting(key: str) -> Any:
try:
key = key.upper()
return context.get("settings", global_settings).get(key, global_settings.get(key, None))
return context.get("settings", global_settings).get(
key, global_settings.get(key, None)
)
except Exception:
return global_settings.get(key, None)
@ -72,14 +76,29 @@ def emphasize_header(text: str, only_markdown=False, reference_link=None) -> str
# Everything before the colon (inclusive) is wrapped in <strong> tags
if only_markdown:
if reference_link:
transformed_string = f"[**{text[:colon_position + 1]}**]({reference_link})\n" + text[colon_position + 1:]
transformed_string = (
f"[**{text[:colon_position + 1]}**]({reference_link})\n"
+ text[colon_position + 1 :]
)
else:
transformed_string = f"**{text[:colon_position + 1]}**\n" + text[colon_position + 1:]
transformed_string = (
f"**{text[:colon_position + 1]}**\n"
+ text[colon_position + 1 :]
)
else:
if reference_link:
transformed_string = f"<strong><a href='{reference_link}'>{text[:colon_position + 1]}</a></strong><br>" + text[colon_position + 1:]
transformed_string = (
f"<strong><a href='{reference_link}'>{text[:colon_position + 1]}</a></strong><br>"
+ text[colon_position + 1 :]
)
else:
transformed_string = "<strong>" + text[:colon_position + 1] + "</strong>" +'<br>' + text[colon_position + 1:]
transformed_string = (
"<strong>"
+ text[: colon_position + 1]
+ "</strong>"
+ '<br>'
+ text[colon_position + 1 :]
)
else:
# If there's no ": ", return the original string
transformed_string = text
@ -101,11 +120,14 @@ def unique_strings(input_list: List[str]) -> List[str]:
seen.add(item)
return unique_list
def convert_to_markdown_v2(output_data: dict,
gfm_supported: bool = True,
incremental_review=None,
git_provider=None,
files=None) -> str:
def convert_to_markdown_v2(
output_data: dict,
gfm_supported: bool = True,
incremental_review=None,
git_provider=None,
files=None,
) -> str:
"""
Convert a dictionary of data into markdown format.
Args:
@ -183,7 +205,9 @@ def convert_to_markdown_v2(output_data: dict,
else:
markdown_text += f"### {emoji} PR 包含测试\n\n"
elif 'ticket compliance check' in key_nice.lower():
markdown_text = ticket_markdown_logic(emoji, markdown_text, value, gfm_supported)
markdown_text = ticket_markdown_logic(
emoji, markdown_text, value, gfm_supported
)
elif 'security concerns' in key_nice.lower():
if gfm_supported:
markdown_text += f"<tr><td>"
@ -220,7 +244,9 @@ def convert_to_markdown_v2(output_data: dict,
if gfm_supported:
markdown_text += f"<tr><td>"
# markdown_text += f"{emoji}&nbsp;<strong>{key_nice}</strong><br><br>\n\n"
markdown_text += f"{emoji}&nbsp;<strong>建议评审的重点领域</strong><br><br>\n\n"
markdown_text += (
f"{emoji}&nbsp;<strong>建议评审的重点领域</strong><br><br>\n\n"
)
else:
markdown_text += f"### {emoji} 建议评审的重点领域\n\n#### \n"
for i, issue in enumerate(issues):
@ -235,9 +261,13 @@ def convert_to_markdown_v2(output_data: dict,
start_line = int(str(issue.get('start_line', 0)).strip())
end_line = int(str(issue.get('end_line', 0)).strip())
relevant_lines_str = extract_relevant_lines_str(end_line, files, relevant_file, start_line, dedent=True)
relevant_lines_str = extract_relevant_lines_str(
end_line, files, relevant_file, start_line, dedent=True
)
if git_provider:
reference_link = git_provider.get_line_link(relevant_file, start_line, end_line)
reference_link = git_provider.get_line_link(
relevant_file, start_line, end_line
)
else:
reference_link = None
@ -256,7 +286,9 @@ def convert_to_markdown_v2(output_data: dict,
issue_str = f"**{issue_header}**\n\n{issue_content}\n\n"
markdown_text += f"{issue_str}\n\n"
except Exception as e:
get_logger().exception(f"Failed to process 'Recommended focus areas for review': {e}")
get_logger().exception(
f"Failed to process 'Recommended focus areas for review': {e}"
)
if gfm_supported:
markdown_text += f"</td></tr>\n"
else:
@ -273,7 +305,9 @@ def convert_to_markdown_v2(output_data: dict,
return markdown_text
def extract_relevant_lines_str(end_line, files, relevant_file, start_line, dedent=False) -> str:
def extract_relevant_lines_str(
end_line, files, relevant_file, start_line, dedent=False
) -> str:
"""
Finds 'relevant_file' in 'files', and extracts the lines from 'start_line' to 'end_line' string from the file content.
"""
@ -286,10 +320,16 @@ def extract_relevant_lines_str(end_line, files, relevant_file, start_line, deden
if not file.head_file:
# as a fallback, extract relevant lines directly from patch
patch = file.patch
get_logger().info(f"No content found in file: '{file.filename}' for 'extract_relevant_lines_str'. Using patch instead")
_, selected_lines = extract_hunk_lines_from_patch(patch, file.filename, start_line, end_line,side='right')
get_logger().info(
f"No content found in file: '{file.filename}' for 'extract_relevant_lines_str'. Using patch instead"
)
_, selected_lines = extract_hunk_lines_from_patch(
patch, file.filename, start_line, end_line, side='right'
)
if not selected_lines:
get_logger().error(f"Failed to extract relevant lines from patch: {file.filename}")
get_logger().error(
f"Failed to extract relevant lines from patch: {file.filename}"
)
return ""
# filter out '-' lines
relevant_lines_str = ""
@ -299,12 +339,16 @@ def extract_relevant_lines_str(end_line, files, relevant_file, start_line, deden
relevant_lines_str += line[1:] + '\n'
else:
relevant_file_lines = file.head_file.splitlines()
relevant_lines_str = "\n".join(relevant_file_lines[start_line - 1:end_line])
relevant_lines_str = "\n".join(
relevant_file_lines[start_line - 1 : end_line]
)
if dedent and relevant_lines_str:
# Remove the longest leading string of spaces and tabs common to all lines.
relevant_lines_str = textwrap.dedent(relevant_lines_str)
relevant_lines_str = f"```{file.language}\n{relevant_lines_str}\n```"
relevant_lines_str = (
f"```{file.language}\n{relevant_lines_str}\n```"
)
break
return relevant_lines_str
@ -325,14 +369,21 @@ def ticket_markdown_logic(emoji, markdown_text, value, gfm_supported) -> str:
ticket_url = ticket_analysis.get('ticket_url', '').strip()
explanation = ''
ticket_compliance_level = '' # Individual ticket compliance
fully_compliant_str = ticket_analysis.get('fully_compliant_requirements', '').strip()
not_compliant_str = ticket_analysis.get('not_compliant_requirements', '').strip()
requires_further_human_verification = ticket_analysis.get('requires_further_human_verification',
'').strip()
fully_compliant_str = ticket_analysis.get(
'fully_compliant_requirements', ''
).strip()
not_compliant_str = ticket_analysis.get(
'not_compliant_requirements', ''
).strip()
requires_further_human_verification = ticket_analysis.get(
'requires_further_human_verification', ''
).strip()
if not fully_compliant_str and not not_compliant_str:
get_logger().debug(f"Ticket compliance has no requirements",
artifact={'ticket_url': ticket_url})
get_logger().debug(
f"Ticket compliance has no requirements",
artifact={'ticket_url': ticket_url},
)
continue
# Calculate individual ticket compliance level
@ -353,19 +404,27 @@ def ticket_markdown_logic(emoji, markdown_text, value, gfm_supported) -> str:
# build compliance string
if fully_compliant_str:
explanation += f"Compliant requirements:\n\n{fully_compliant_str}\n\n"
explanation += (
f"Compliant requirements:\n\n{fully_compliant_str}\n\n"
)
if not_compliant_str:
explanation += f"Non-compliant requirements:\n\n{not_compliant_str}\n\n"
explanation += (
f"Non-compliant requirements:\n\n{not_compliant_str}\n\n"
)
if requires_further_human_verification:
explanation += f"Requires further human verification:\n\n{requires_further_human_verification}\n\n"
ticket_compliance_str += f"\n\n**[{ticket_url.split('/')[-1]}]({ticket_url}) - {ticket_compliance_level}**\n\n{explanation}\n\n"
# for debugging
if requires_further_human_verification:
get_logger().debug(f"Ticket compliance requires further human verification",
artifact={'ticket_url': ticket_url,
'requires_further_human_verification': requires_further_human_verification,
'compliance_level': ticket_compliance_level})
get_logger().debug(
f"Ticket compliance requires further human verification",
artifact={
'ticket_url': ticket_url,
'requires_further_human_verification': requires_further_human_verification,
'compliance_level': ticket_compliance_level,
},
)
except Exception as e:
get_logger().exception(f"Failed to process ticket compliance: {e}")
@ -381,7 +440,10 @@ def ticket_markdown_logic(emoji, markdown_text, value, gfm_supported) -> str:
compliance_emoji = ''
elif any(level == 'Not compliant' for level in all_compliance_levels):
# If there's a mix of compliant and non-compliant tickets
if any(level in ['Fully compliant', 'PR Code Verified'] for level in all_compliance_levels):
if any(
level in ['Fully compliant', 'PR Code Verified']
for level in all_compliance_levels
):
compliance_level = 'Partially compliant'
compliance_emoji = '🔶'
else:
@ -395,7 +457,9 @@ def ticket_markdown_logic(emoji, markdown_text, value, gfm_supported) -> str:
compliance_emoji = ''
# Set extra statistics outside the ticket loop
get_settings().set('config.extra_statistics', {'compliance_level': compliance_level})
get_settings().set(
'config.extra_statistics', {'compliance_level': compliance_level}
)
# editing table row for ticket compliance analysis
if gfm_supported:
@ -425,7 +489,9 @@ def process_can_be_split(emoji, value):
for i, split in enumerate(value):
title = split.get('title', '')
relevant_files = split.get('relevant_files', [])
markdown_text += f"<details><summary>\n子 PR 主题: <b>{title}</b></summary>\n\n"
markdown_text += (
f"<details><summary>\n子 PR 主题: <b>{title}</b></summary>\n\n"
)
markdown_text += f"___\n\n相关文件:\n\n"
for file in relevant_files:
markdown_text += f"- {file}\n"
@ -464,7 +530,9 @@ def process_can_be_split(emoji, value):
return markdown_text
def parse_code_suggestion(code_suggestion: dict, i: int = 0, gfm_supported: bool = True) -> str:
def parse_code_suggestion(
code_suggestion: dict, i: int = 0, gfm_supported: bool = True
) -> str:
"""
Convert a dictionary of data into markdown format.
@ -484,15 +552,19 @@ def parse_code_suggestion(code_suggestion: dict, i: int = 0, gfm_supported: bool
markdown_text += f"<tr><td>相关文件</td><td>{relevant_file}</td></tr>"
# continue
elif sub_key.lower() == 'suggestion':
markdown_text += (f"<tr><td>{sub_key} &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;</td>"
f"<td>\n\n<strong>\n\n{sub_value.strip()}\n\n</strong>\n</td></tr>")
markdown_text += (
f"<tr><td>{sub_key} &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;</td>"
f"<td>\n\n<strong>\n\n{sub_value.strip()}\n\n</strong>\n</td></tr>"
)
elif sub_key.lower() == 'relevant_line':
markdown_text += f"<tr><td>相关行</td>"
sub_value_list = sub_value.split('](')
relevant_line = sub_value_list[0].lstrip('`').lstrip('[')
if len(sub_value_list) > 1:
link = sub_value_list[1].rstrip(')').strip('`')
markdown_text += f"<td><a href='{link}'>{relevant_line}</a></td>"
markdown_text += (
f"<td><a href='{link}'>{relevant_line}</a></td>"
)
else:
markdown_text += f"<td>{relevant_line}</td>"
markdown_text += "</tr>"
@ -505,11 +577,14 @@ def parse_code_suggestion(code_suggestion: dict, i: int = 0, gfm_supported: bool
for sub_key, sub_value in code_suggestion.items():
if isinstance(sub_key, str):
sub_key = sub_key.rstrip()
if isinstance(sub_value,str):
if isinstance(sub_value, str):
sub_value = sub_value.rstrip()
if isinstance(sub_value, dict): # "code example"
markdown_text += f" - **{sub_key}:**\n"
for code_key, code_value in sub_value.items(): # 'before' and 'after' code
for (
code_key,
code_value,
) in sub_value.items(): # 'before' and 'after' code
code_str = f"```\n{code_value}\n```"
code_str_indented = textwrap.indent(code_str, ' ')
markdown_text += f" - **{code_key}:**\n{code_str_indented}\n"
@ -520,7 +595,9 @@ def parse_code_suggestion(code_suggestion: dict, i: int = 0, gfm_supported: bool
markdown_text += f" **{sub_key}:** {sub_value} \n"
if "relevant_line" not in sub_key.lower(): # nicer presentation
# markdown_text = markdown_text.rstrip('\n') + "\\\n" # works for gitlab
markdown_text = markdown_text.rstrip('\n') + " \n" # works for gitlab and bitbucket
markdown_text = (
markdown_text.rstrip('\n') + " \n"
) # works for gitlab and bitbucket
markdown_text += "\n"
return markdown_text
@ -561,9 +638,15 @@ def try_fix_json(review, max_iter=10, code_suggestions=False):
else:
closing_bracket = "]}}"
if (review.rfind("'Code feedback': [") > 0 or review.rfind('"Code feedback": [') > 0) or \
(review.rfind("'Code suggestions': [") > 0 or review.rfind('"Code suggestions": [') > 0) :
last_code_suggestion_ind = [m.end() for m in re.finditer(r"\}\s*,", review)][-1] - 1
if (
review.rfind("'Code feedback': [") > 0 or review.rfind('"Code feedback": [') > 0
) or (
review.rfind("'Code suggestions': [") > 0
or review.rfind('"Code suggestions": [') > 0
):
last_code_suggestion_ind = [m.end() for m in re.finditer(r"\}\s*,", review)][
-1
] - 1
valid_json = False
iter_count = 0
@ -574,7 +657,9 @@ def try_fix_json(review, max_iter=10, code_suggestions=False):
review = review[:last_code_suggestion_ind].strip() + closing_bracket
except json.decoder.JSONDecodeError:
review = review[:last_code_suggestion_ind]
last_code_suggestion_ind = [m.end() for m in re.finditer(r"\}\s*,", review)][-1] - 1
last_code_suggestion_ind = [
m.end() for m in re.finditer(r"\}\s*,", review)
][-1] - 1
iter_count += 1
if not valid_json:
@ -629,7 +714,12 @@ def convert_str_to_datetime(date_str):
return datetime.strptime(date_str, datetime_format)
def load_large_diff(filename, new_file_content_str: str, original_file_content_str: str, show_warning: bool = True) -> str:
def load_large_diff(
filename,
new_file_content_str: str,
original_file_content_str: str,
show_warning: bool = True,
) -> str:
"""
Generate a patch for a modified file by comparing the original content of the file with the new content provided as
input.
@ -640,10 +730,14 @@ def load_large_diff(filename, new_file_content_str: str, original_file_content_s
try:
original_file_content_str = (original_file_content_str or "").rstrip() + "\n"
new_file_content_str = (new_file_content_str or "").rstrip() + "\n"
diff = difflib.unified_diff(original_file_content_str.splitlines(keepends=True),
new_file_content_str.splitlines(keepends=True))
diff = difflib.unified_diff(
original_file_content_str.splitlines(keepends=True),
new_file_content_str.splitlines(keepends=True),
)
if get_settings().config.verbosity_level >= 2 and show_warning:
get_logger().info(f"File was modified, but no patch was found. Manually creating patch: {filename}.")
get_logger().info(
f"File was modified, but no patch was found. Manually creating patch: {filename}."
)
patch = ''.join(diff)
return patch
except Exception as e:
@ -693,42 +787,68 @@ def _fix_key_value(key: str, value: str):
try:
value = yaml.safe_load(value)
except Exception as e:
get_logger().debug(f"Failed to parse YAML for config override {key}={value}", exc_info=e)
get_logger().debug(
f"Failed to parse YAML for config override {key}={value}", exc_info=e
)
return key, value
def load_yaml(response_text: str, keys_fix_yaml: List[str] = [], first_key="", last_key="") -> dict:
response_text = response_text.strip('\n').removeprefix('```yaml').rstrip().removesuffix('```')
def load_yaml(
response_text: str, keys_fix_yaml: List[str] = [], first_key="", last_key=""
) -> dict:
response_text = (
response_text.strip('\n').removeprefix('```yaml').rstrip().removesuffix('```')
)
try:
data = yaml.safe_load(response_text)
except Exception as e:
get_logger().warning(f"Initial failure to parse AI prediction: {e}")
data = try_fix_yaml(response_text, keys_fix_yaml=keys_fix_yaml, first_key=first_key, last_key=last_key)
data = try_fix_yaml(
response_text,
keys_fix_yaml=keys_fix_yaml,
first_key=first_key,
last_key=last_key,
)
if not data:
get_logger().error(f"Failed to parse AI prediction after fallbacks",
artifact={'response_text': response_text})
get_logger().error(
f"Failed to parse AI prediction after fallbacks",
artifact={'response_text': response_text},
)
else:
get_logger().info(f"Successfully parsed AI prediction after fallbacks",
artifact={'response_text': response_text})
get_logger().info(
f"Successfully parsed AI prediction after fallbacks",
artifact={'response_text': response_text},
)
return data
def try_fix_yaml(response_text: str,
keys_fix_yaml: List[str] = [],
first_key="",
last_key="",) -> dict:
def try_fix_yaml(
response_text: str,
keys_fix_yaml: List[str] = [],
first_key="",
last_key="",
) -> dict:
response_text_lines = response_text.split('\n')
keys_yaml = ['relevant line:', 'suggestion content:', 'relevant file:', 'existing code:', 'improved code:']
keys_yaml = [
'relevant line:',
'suggestion content:',
'relevant file:',
'existing code:',
'improved code:',
]
keys_yaml = keys_yaml + keys_fix_yaml
# first fallback - try to convert 'relevant line: ...' to relevant line: |-\n ...'
response_text_lines_copy = response_text_lines.copy()
for i in range(0, len(response_text_lines_copy)):
for key in keys_yaml:
if key in response_text_lines_copy[i] and not '|' in response_text_lines_copy[i]:
response_text_lines_copy[i] = response_text_lines_copy[i].replace(f'{key}',
f'{key} |\n ')
if (
key in response_text_lines_copy[i]
and not '|' in response_text_lines_copy[i]
):
response_text_lines_copy[i] = response_text_lines_copy[i].replace(
f'{key}', f'{key} |\n '
)
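# Illustrative effect of the replacement above: a model output line such as
#   relevant line: if a == 1:
# is rewritten to
#   relevant line: |
#       if a == 1:
# turning the value into a YAML block scalar so the extra ':' no longer breaks
# yaml.safe_load (the exact indentation comes from the string literal above).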
try:
data = yaml.safe_load('\n'.join(response_text_lines_copy))
get_logger().info(f"Successfully parsed AI prediction after adding |-\n")
@ -743,22 +863,26 @@ def try_fix_yaml(response_text: str,
snippet_text = snippet.group()
try:
data = yaml.safe_load(snippet_text.removeprefix('```yaml').rstrip('`'))
get_logger().info(f"Successfully parsed AI prediction after extracting yaml snippet")
get_logger().info(
f"Successfully parsed AI prediction after extracting yaml snippet"
)
return data
except:
pass
# third fallback - try to remove leading and trailing curly brackets
response_text_copy = response_text.strip().rstrip().removeprefix('{').removesuffix('}').rstrip(':\n')
response_text_copy = (
response_text.strip().rstrip().removeprefix('{').removesuffix('}').rstrip(':\n')
)
try:
data = yaml.safe_load(response_text_copy)
get_logger().info(f"Successfully parsed AI prediction after removing curly brackets")
get_logger().info(
f"Successfully parsed AI prediction after removing curly brackets"
)
return data
except:
pass
# fourth fallback - try to extract yaml snippet by 'first_key' and 'last_key'
# note that 'last_key' can be in practice a key that is not the last key in the yaml snippet.
# it just needs to be some inner key, so we can look for newlines after it
@ -767,13 +891,23 @@ def try_fix_yaml(response_text: str,
if index_start == -1:
index_start = response_text.find(f"{first_key}:")
index_last_code = response_text.rfind(f"{last_key}:")
index_end = response_text.find("\n\n", index_last_code) # look for newlines after last_key
index_end = response_text.find(
"\n\n", index_last_code
) # look for newlines after last_key
if index_end == -1:
index_end = len(response_text)
response_text_copy = response_text[index_start:index_end].strip().strip('```yaml').strip('`').strip()
response_text_copy = (
response_text[index_start:index_end]
.strip()
.strip('```yaml')
.strip('`')
.strip()
)
try:
data = yaml.safe_load(response_text_copy)
get_logger().info(f"Successfully parsed AI prediction after extracting yaml snippet")
get_logger().info(
f"Successfully parsed AI prediction after extracting yaml snippet"
)
return data
except:
pass
@ -784,7 +918,9 @@ def try_fix_yaml(response_text: str,
response_text_lines_copy[i] = ' ' + response_text_lines_copy[i][1:]
try:
data = yaml.safe_load('\n'.join(response_text_lines_copy))
get_logger().info(f"Successfully parsed AI prediction after removing leading '+'")
get_logger().info(
f"Successfully parsed AI prediction after removing leading '+'"
)
return data
except:
pass
@ -794,7 +930,9 @@ def try_fix_yaml(response_text: str,
response_text_lines_tmp = '\n'.join(response_text_lines[:-i])
try:
data = yaml.safe_load(response_text_lines_tmp)
get_logger().info(f"Successfully parsed AI prediction after removing {i} lines")
get_logger().info(
f"Successfully parsed AI prediction after removing {i} lines"
)
return data
except:
pass
@ -820,11 +958,14 @@ def set_custom_labels(variables, git_provider=None):
for k, v in labels.items():
description = "'" + v['description'].strip('\n').replace('\n', '\\n') + "'"
# variables["custom_labels_class"] += f"\n {k.lower().replace(' ', '_')} = '{k}' # {description}"
variables["custom_labels_class"] += f"\n {k.lower().replace(' ', '_')} = {description}"
variables[
"custom_labels_class"
] += f"\n {k.lower().replace(' ', '_')} = {description}"
labels_minimal_to_labels_dict[k.lower().replace(' ', '_')] = k
counter += 1
variables["labels_minimal_to_labels_dict"] = labels_minimal_to_labels_dict
def get_user_labels(current_labels: List[str] = None):
"""
Only keep labels that have been added by the user
@ -866,14 +1007,22 @@ def get_max_tokens(model):
elif settings.config.custom_model_max_tokens > 0:
max_tokens_model = settings.config.custom_model_max_tokens
else:
raise Exception(f"Ensure {model} is defined in MAX_TOKENS in ./pr_agent/algo/__init__.py or set a positive value for it in config.custom_model_max_tokens")
raise Exception(
f"Ensure {model} is defined in MAX_TOKENS in ./pr_agent/algo/__init__.py or set a positive value for it in config.custom_model_max_tokens"
)
if settings.config.max_model_tokens and settings.config.max_model_tokens > 0:
max_tokens_model = min(settings.config.max_model_tokens, max_tokens_model)
return max_tokens_model
def clip_tokens(text: str, max_tokens: int, add_three_dots=True, num_input_tokens=None, delete_last_line=False) -> str:
def clip_tokens(
text: str,
max_tokens: int,
add_three_dots=True,
num_input_tokens=None,
delete_last_line=False,
) -> str:
"""
Clip the number of tokens in a string to a maximum number of tokens.
@ -909,14 +1058,15 @@ def clip_tokens(text: str, max_tokens: int, add_three_dots=True, num_input_token
clipped_text = clipped_text.rsplit('\n', 1)[0]
if add_three_dots:
clipped_text += "\n...(truncated)"
else: # if the text is empty
clipped_text = ""
else: # if the text is empty
clipped_text = ""
return clipped_text
except Exception as e:
get_logger().warning(f"Failed to clip tokens: {e}")
return text
def replace_code_tags(text):
"""
Replace odd instances of ` with <code> and even instances of ` with </code>
@ -928,15 +1078,16 @@ def replace_code_tags(text):
return ''.join(parts)
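# Example of the behavior described in the docstring above:
#   replace_code_tags("wrap `foo` and `bar`")
#   returns "wrap <code>foo</code> and <code>bar</code>"
# since odd-numbered backticks open a <code> tag and even-numbered ones close it.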
def find_line_number_of_relevant_line_in_file(diff_files: List[FilePatchInfo],
relevant_file: str,
relevant_line_in_file: str,
absolute_position: int = None) -> Tuple[int, int]:
def find_line_number_of_relevant_line_in_file(
diff_files: List[FilePatchInfo],
relevant_file: str,
relevant_line_in_file: str,
absolute_position: int = None,
) -> Tuple[int, int]:
position = -1
if absolute_position is None:
absolute_position = -1
re_hunk_header = re.compile(
r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)")
re_hunk_header = re.compile(r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)")
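# For reference, the hunk-header regex above captures, for a line like
#   "@@ -171,24 +224,47 @@ def foo():"
# the groups ('171', '24', '224', '47', 'def foo():'), i.e. old start/size,
# new start/size, and the trailing context text.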
if not diff_files:
return position, absolute_position
@ -947,7 +1098,7 @@ def find_line_number_of_relevant_line_in_file(diff_files: List[FilePatchInfo],
patch_lines = patch.splitlines()
delta = 0
start1, size1, start2, size2 = 0, 0, 0, 0
if absolute_position != -1: # matching absolute to relative
if absolute_position != -1: # matching absolute to relative
for i, line in enumerate(patch_lines):
# new hunk
if line.startswith('@@'):
@ -965,12 +1116,12 @@ def find_line_number_of_relevant_line_in_file(diff_files: List[FilePatchInfo],
break
else:
# try to find the line in the patch using difflib, with some margin of error
matches_difflib: list[str | Any] = difflib.get_close_matches(relevant_line_in_file,
patch_lines, n=3, cutoff=0.93)
matches_difflib: list[str | Any] = difflib.get_close_matches(
relevant_line_in_file, patch_lines, n=3, cutoff=0.93
)
if len(matches_difflib) == 1 and matches_difflib[0].startswith('+'):
relevant_line_in_file = matches_difflib[0]
for i, line in enumerate(patch_lines):
if line.startswith('@@'):
delta = 0
@ -1002,19 +1153,26 @@ def find_line_number_of_relevant_line_in_file(diff_files: List[FilePatchInfo],
break
return position, absolute_position
def get_rate_limit_status(github_token) -> dict:
GITHUB_API_URL = get_settings(use_context=False).get("GITHUB.BASE_URL", "https://api.github.com").rstrip("/") # "https://api.github.com"
GITHUB_API_URL = (
get_settings(use_context=False)
.get("GITHUB.BASE_URL", "https://api.github.com")
.rstrip("/")
) # "https://api.github.com"
# GITHUB_API_URL = "https://api.github.com"
RATE_LIMIT_URL = f"{GITHUB_API_URL}/rate_limit"
HEADERS = {
"Accept": "application/vnd.github.v3+json",
"Authorization": f"token {github_token}"
"Authorization": f"token {github_token}",
}
response = requests.get(RATE_LIMIT_URL, headers=HEADERS)
try:
rate_limit_info = response.json()
if rate_limit_info.get('message') == 'Rate limiting is not enabled.': # for github enterprise
if (
rate_limit_info.get('message') == 'Rate limiting is not enabled.'
): # for github enterprise
return {'resources': {}}
response.raise_for_status() # Check for HTTP errors
except: # retry
@ -1024,12 +1182,16 @@ def get_rate_limit_status(github_token) -> dict:
return rate_limit_info
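# For reference, GitHub's /rate_limit endpoint returns JSON shaped roughly like
#   {"resources": {"core": {"limit": 5000, "remaining": 4999, "reset": 1691591363}, ...}, "rate": {...}}
# which is what the 'remaining' / 'limit' / 'reset' lookups in the callers below
# rely on; on a GitHub Enterprise server with rate limiting disabled, the guard
# above returns an empty {'resources': {}} instead.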
def validate_rate_limit_github(github_token, installation_id=None, threshold=0.1) -> bool:
def validate_rate_limit_github(
github_token, installation_id=None, threshold=0.1
) -> bool:
try:
rate_limit_status = get_rate_limit_status(github_token)
if installation_id:
get_logger().debug(f"installation_id: {installation_id}, Rate limit status: {rate_limit_status['rate']}")
# validate that the rate limit is not exceeded
get_logger().debug(
f"installation_id: {installation_id}, Rate limit status: {rate_limit_status['rate']}"
)
# validate that the rate limit is not exceeded
# validate that the rate limit is not exceeded
for key, value in rate_limit_status['resources'].items():
if value['remaining'] < value['limit'] * threshold:
@ -1037,8 +1199,9 @@ def validate_rate_limit_github(github_token, installation_id=None, threshold=0.1
return False
return True
except Exception as e:
get_logger().error(f"Error in rate limit {e}",
artifact={"traceback": traceback.format_exc()})
get_logger().error(
f"Error in rate limit {e}", artifact={"traceback": traceback.format_exc()}
)
return True
@ -1051,7 +1214,9 @@ def validate_and_await_rate_limit(github_token):
get_logger().error(f"key: {key}, value: {value}")
sleep_time_sec = value['reset'] - datetime.now().timestamp()
sleep_time_hour = sleep_time_sec / 3600.0
get_logger().error(f"Rate limit exceeded. Sleeping for {sleep_time_hour} hours")
get_logger().error(
f"Rate limit exceeded. Sleeping for {sleep_time_hour} hours"
)
if sleep_time_sec > 0:
time.sleep(sleep_time_sec + 1)
rate_limit_status = get_rate_limit_status(github_token)
@ -1068,22 +1233,39 @@ def github_action_output(output_data: dict, key_name: str):
key_data = output_data.get(key_name, {})
with open(os.environ['GITHUB_OUTPUT'], 'a') as fh:
print(f"{key_name}={json.dumps(key_data, indent=None, ensure_ascii=False)}", file=fh)
print(
f"{key_name}={json.dumps(key_data, indent=None, ensure_ascii=False)}",
file=fh,
)
except Exception as e:
get_logger().error(f"Failed to write to GitHub Action output: {e}")
return
def show_relevant_configurations(relevant_section: str) -> str:
skip_keys = ['ai_disclaimer', 'ai_disclaimer_title', 'ANALYTICS_FOLDER', 'secret_provider', "skip_keys", "app_id", "redirect",
'trial_prefix_message', 'no_eligible_message', 'identity_provider', 'ALLOWED_REPOS','APP_NAME']
skip_keys = [
'ai_disclaimer',
'ai_disclaimer_title',
'ANALYTICS_FOLDER',
'secret_provider',
"skip_keys",
"app_id",
"redirect",
'trial_prefix_message',
'no_eligible_message',
'identity_provider',
'ALLOWED_REPOS',
'APP_NAME',
]
extra_skip_keys = get_settings().config.get('config.skip_keys', [])
if extra_skip_keys:
skip_keys.extend(extra_skip_keys)
markdown_text = ""
markdown_text += "\n<hr>\n<details> <summary><strong>🛠️ 相关配置:</strong></summary> \n\n"
markdown_text +="<br>以下是相关工具地配置 [configurations](https://github.com/Codium-ai/pr-agent/blob/main/pr_agent/settings/configuration.toml):\n\n"
markdown_text += (
"\n<hr>\n<details> <summary><strong>🛠️ 相关配置:</strong></summary> \n\n"
)
markdown_text += "<br>以下是相关工具地配置 [configurations](https://github.com/Codium-ai/pr-agent/blob/main/pr_agent/settings/configuration.toml):\n\n"
markdown_text += f"**[config**]\n```yaml\n\n"
for key, value in get_settings().config.items():
if key in skip_keys:
@ -1099,6 +1281,7 @@ def show_relevant_configurations(relevant_section: str) -> str:
markdown_text += "\n</details>\n"
return markdown_text
def is_value_no(value):
if not value:
return True
@ -1122,7 +1305,7 @@ def string_to_uniform_number(s: str) -> float:
# Convert the hash to an integer
hash_int = int(hash_object.hexdigest(), 16)
# Normalize the integer to the range [0, 1]
max_hash_int = 2 ** 256 - 1
max_hash_int = 2**256 - 1
uniform_number = float(hash_int) / max_hash_int
return uniform_number
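# Illustrative use, assuming hash_object above is a SHA-256 digest (hence the
# 2**256 - 1 normalizer): for any input string, e.g. "repo/pr-123" (hypothetical),
# string_to_uniform_number always maps the same string to the same float in [0, 1].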
@ -1131,7 +1314,9 @@ def process_description(description_full: str) -> Tuple[str, List]:
if not description_full:
return "", []
description_split = description_full.split(PRDescriptionHeader.CHANGES_WALKTHROUGH.value)
description_split = description_full.split(
PRDescriptionHeader.CHANGES_WALKTHROUGH.value
)
base_description_str = description_split[0]
changes_walkthrough_str = ""
files = []
@ -1167,45 +1352,58 @@ def process_description(description_full: str) -> Tuple[str, List]:
pattern_back = r'<details>\s*<summary><strong>(.*?)</strong><dd><code>(.*?)</code>.*?</summary>\s*<hr>\s*(.*?)\n\n\s*(.*?)</details>'
res = re.search(pattern_back, file_data, re.DOTALL)
if not res or res.lastindex != 4:
pattern_back = r'<details>\s*<summary><strong>(.*?)</strong>\s*<dd><code>(.*?)</code>.*?</summary>\s*<hr>\s*(.*?)\s*-\s*(.*?)\s*</details>' # looking for hyphen ('- ')
pattern_back = r'<details>\s*<summary><strong>(.*?)</strong>\s*<dd><code>(.*?)</code>.*?</summary>\s*<hr>\s*(.*?)\s*-\s*(.*?)\s*</details>' # looking for hyphen ('- ')
res = re.search(pattern_back, file_data, re.DOTALL)
if res and res.lastindex == 4:
short_filename = res.group(1).strip()
short_summary = res.group(2).strip()
long_filename = res.group(3).strip()
long_summary = res.group(4).strip()
long_summary = long_summary.replace('<br> *', '\n*').replace('<br>','').replace('\n','<br>')
long_summary = res.group(4).strip()
long_summary = (
long_summary.replace('<br> *', '\n*')
.replace('<br>', '')
.replace('\n', '<br>')
)
long_summary = h.handle(long_summary).strip()
if long_summary.startswith('\\-'):
long_summary = "* " + long_summary[2:]
elif not long_summary.startswith('*'):
long_summary = f"* {long_summary}"
files.append({
'short_file_name': short_filename,
'full_file_name': long_filename,
'short_summary': short_summary,
'long_summary': long_summary
})
files.append(
{
'short_file_name': short_filename,
'full_file_name': long_filename,
'short_summary': short_summary,
'long_summary': long_summary,
}
)
else:
if '<code>...</code>' in file_data:
pass # PR with many files. some did not get analyzed
pass # PR with many files. some did not get analyzed
else:
get_logger().error(f"Failed to parse description", artifact={'description': file_data})
get_logger().error(
f"Failed to parse description",
artifact={'description': file_data},
)
except Exception as e:
get_logger().exception(f"Failed to process description: {e}", artifact={'description': file_data})
get_logger().exception(
f"Failed to process description: {e}",
artifact={'description': file_data},
)
except Exception as e:
get_logger().exception(f"Failed to process description: {e}")
return base_description_str, files
def get_version() -> str:
# First check pyproject.toml if running directly out of repository
if os.path.exists("pyproject.toml"):
if sys.version_info >= (3, 11):
import tomllib
with open("pyproject.toml", "rb") as f:
data = tomllib.load(f)
if "project" in data and "version" in data["project"]:
@ -1213,7 +1411,9 @@ def get_version() -> str:
else:
get_logger().warning("Version not found in pyproject.toml")
else:
get_logger().warning("Unable to determine local version from pyproject.toml")
get_logger().warning(
"Unable to determine local version from pyproject.toml"
)
# Otherwise get the installed pip package version
try:
View File
@ -12,8 +12,9 @@ setup_logger(log_level)
def set_parser():
parser = argparse.ArgumentParser(description='AI based pull request analyzer', usage=
"""\
parser = argparse.ArgumentParser(
description='AI based pull request analyzer',
usage="""\
Usage: cli.py --pr-url=<URL on supported git hosting service> <command> [<args>].
For example:
- cli.py --pr_url=... review
@ -45,11 +46,20 @@ def set_parser():
Configuration:
To edit any configuration parameter from 'configuration.toml', just add -config_path=<value>.
For example: 'python cli.py --pr_url=... review --pr_reviewer.extra_instructions="focus on the file: ..."'
""")
parser.add_argument('--version', action='version', version=f'pr-agent {get_version()}')
parser.add_argument('--pr_url', type=str, help='The URL of the PR to review', default=None)
parser.add_argument('--issue_url', type=str, help='The URL of the Issue to review', default=None)
parser.add_argument('command', type=str, help='The', choices=commands, default='review')
""",
)
parser.add_argument(
'--version', action='version', version=f'pr-agent {get_version()}'
)
parser.add_argument(
'--pr_url', type=str, help='The URL of the PR to review', default=None
)
parser.add_argument(
'--issue_url', type=str, help='The URL of the Issue to review', default=None
)
parser.add_argument(
'command', type=str, help='The', choices=commands, default='review'
)
parser.add_argument('rest', nargs=argparse.REMAINDER, default=[])
return parser
@ -76,14 +86,24 @@ def run(inargs=None, args=None):
async def inner():
if args.issue_url:
result = await asyncio.create_task(PRAgent().handle_request(args.issue_url, [command] + args.rest))
result = await asyncio.create_task(
PRAgent().handle_request(args.issue_url, [command] + args.rest)
)
else:
result = await asyncio.create_task(PRAgent().handle_request(args.pr_url, [command] + args.rest))
result = await asyncio.create_task(
PRAgent().handle_request(args.pr_url, [command] + args.rest)
)
if get_settings().litellm.get("enable_callbacks", False):
# There may be additional events on the event queue from the run above. If there are give them time to complete.
get_logger().debug("Waiting for event queue to complete")
await asyncio.wait([task for task in asyncio.all_tasks() if task is not asyncio.current_task()])
await asyncio.wait(
[
task
for task in asyncio.all_tasks()
if task is not asyncio.current_task()
]
)
return result
View File
@ -7,7 +7,9 @@ def main():
provider = "github" # GitHub provider
user_token = "..." # GitHub user token
openai_key = "..." # OpenAI key
pr_url = "..." # PR URL, for example 'https://github.com/Codium-ai/pr-agent/pull/809'
pr_url = (
"..." # PR URL, for example 'https://github.com/Codium-ai/pr-agent/pull/809'
)
command = "/review" # Command to run (e.g. '/review', '/describe', '/ask="What is the purpose of this PR?"')
# Setting the configurations
View File
@ -11,26 +11,29 @@ current_dir = dirname(abspath(__file__))
global_settings = Dynaconf(
envvar_prefix=False,
merge_enabled=True,
settings_files=[join(current_dir, f) for f in [
"settings/configuration.toml",
"settings/ignore.toml",
"settings/language_extensions.toml",
"settings/pr_reviewer_prompts.toml",
"settings/pr_questions_prompts.toml",
"settings/pr_line_questions_prompts.toml",
"settings/pr_description_prompts.toml",
"settings/pr_code_suggestions_prompts.toml",
"settings/pr_code_suggestions_reflect_prompts.toml",
"settings/pr_sort_code_suggestions_prompts.toml",
"settings/pr_information_from_user_prompts.toml",
"settings/pr_update_changelog_prompts.toml",
"settings/pr_custom_labels.toml",
"settings/pr_add_docs.toml",
"settings/custom_labels.toml",
"settings/pr_help_prompts.toml",
"settings/.secrets.toml",
"settings_prod/.secrets.toml",
]]
settings_files=[
join(current_dir, f)
for f in [
"settings/configuration.toml",
"settings/ignore.toml",
"settings/language_extensions.toml",
"settings/pr_reviewer_prompts.toml",
"settings/pr_questions_prompts.toml",
"settings/pr_line_questions_prompts.toml",
"settings/pr_description_prompts.toml",
"settings/pr_code_suggestions_prompts.toml",
"settings/pr_code_suggestions_reflect_prompts.toml",
"settings/pr_sort_code_suggestions_prompts.toml",
"settings/pr_information_from_user_prompts.toml",
"settings/pr_update_changelog_prompts.toml",
"settings/pr_custom_labels.toml",
"settings/pr_add_docs.toml",
"settings/custom_labels.toml",
"settings/pr_help_prompts.toml",
"settings/.secrets.toml",
"settings_prod/.secrets.toml",
]
],
)
View File
@ -3,8 +3,9 @@ from starlette_context import context
from utils.pr_agent.config_loader import get_settings
from utils.pr_agent.git_providers.azuredevops_provider import AzureDevopsProvider
from utils.pr_agent.git_providers.bitbucket_provider import BitbucketProvider
from utils.pr_agent.git_providers.bitbucket_server_provider import \
BitbucketServerProvider
from utils.pr_agent.git_providers.bitbucket_server_provider import (
BitbucketServerProvider,
)
from utils.pr_agent.git_providers.codecommit_provider import CodeCommitProvider
from utils.pr_agent.git_providers.gerrit_provider import GerritProvider
from utils.pr_agent.git_providers.git_provider import GitProvider
@ -28,7 +29,9 @@ def get_git_provider():
try:
provider_id = get_settings().config.git_provider
except AttributeError as e:
raise ValueError("git_provider is a required attribute in the configuration file") from e
raise ValueError(
"git_provider is a required attribute in the configuration file"
) from e
if provider_id not in _GIT_PROVIDERS:
raise ValueError(f"Unknown git provider: {provider_id}")
return _GIT_PROVIDERS[provider_id]
View File
@ -6,25 +6,33 @@ from utils.pr_agent.algo.types import EDIT_TYPE, FilePatchInfo
from ..algo.file_filter import filter_ignored
from ..algo.language_handler import is_valid_file
from ..algo.utils import (PRDescriptionHeader, find_line_number_of_relevant_line_in_file,
load_large_diff)
from ..algo.utils import (
PRDescriptionHeader,
find_line_number_of_relevant_line_in_file,
load_large_diff,
)
from ..config_loader import get_settings
from ..log import get_logger
from .git_provider import GitProvider
AZURE_DEVOPS_AVAILABLE = True
ADO_APP_CLIENT_DEFAULT_ID = "499b84ac-1321-427f-aa17-267ca6975798/.default"
MAX_PR_DESCRIPTION_AZURE_LENGTH = 4000-1
MAX_PR_DESCRIPTION_AZURE_LENGTH = 4000 - 1
try:
# noinspection PyUnresolvedReferences
# noinspection PyUnresolvedReferences
from azure.devops.connection import Connection
# noinspection PyUnresolvedReferences
from azure.devops.v7_1.git.models import (Comment, CommentThread,
GitPullRequest,
GitPullRequestIterationChanges,
GitVersionDescriptor)
from azure.devops.v7_1.git.models import (
Comment,
CommentThread,
GitPullRequest,
GitPullRequestIterationChanges,
GitVersionDescriptor,
)
# noinspection PyUnresolvedReferences
from azure.identity import DefaultAzureCredential
from msrest.authentication import BasicAuthentication
@ -33,9 +41,8 @@ except ImportError:
class AzureDevopsProvider(GitProvider):
def __init__(
self, pr_url: Optional[str] = None, incremental: Optional[bool] = False
self, pr_url: Optional[str] = None, incremental: Optional[bool] = False
):
if not AZURE_DEVOPS_AVAILABLE:
raise ImportError(
@ -67,13 +74,16 @@ class AzureDevopsProvider(GitProvider):
if not relevant_lines_start or relevant_lines_start == -1:
get_logger().warning(
f"Failed to publish code suggestion, relevant_lines_start is {relevant_lines_start}")
f"Failed to publish code suggestion, relevant_lines_start is {relevant_lines_start}"
)
continue
if relevant_lines_end < relevant_lines_start:
get_logger().warning(f"Failed to publish code suggestion, "
f"relevant_lines_end is {relevant_lines_end} and "
f"relevant_lines_start is {relevant_lines_start}")
get_logger().warning(
f"Failed to publish code suggestion, "
f"relevant_lines_end is {relevant_lines_end} and "
f"relevant_lines_start is {relevant_lines_start}"
)
continue
if relevant_lines_end > relevant_lines_start:
@ -98,30 +108,32 @@ class AzureDevopsProvider(GitProvider):
for post_parameters in post_parameters_list:
try:
comment = Comment(content=post_parameters["body"], comment_type=1)
thread = CommentThread(comments=[comment],
thread_context={
"filePath": post_parameters["path"],
"rightFileStart": {
"line": post_parameters["start_line"],
"offset": 1,
},
"rightFileEnd": {
"line": post_parameters["line"],
"offset": 1,
},
})
thread = CommentThread(
comments=[comment],
thread_context={
"filePath": post_parameters["path"],
"rightFileStart": {
"line": post_parameters["start_line"],
"offset": 1,
},
"rightFileEnd": {
"line": post_parameters["line"],
"offset": 1,
},
},
)
self.azure_devops_client.create_thread(
comment_thread=thread,
project=self.workspace_slug,
repository_id=self.repo_slug,
pull_request_id=self.pr_num
pull_request_id=self.pr_num,
)
except Exception as e:
get_logger().warning(f"Azure failed to publish code suggestion, error: {e}")
get_logger().warning(
f"Azure failed to publish code suggestion, error: {e}"
)
return True
def get_pr_description_full(self) -> str:
return self.pr.description
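For context, the CommentThread built in the hunk above is driven by a plain nested dict. A minimal standalone sketch of that shape, with hypothetical path and line values (the real code fills them from post_parameters and sends the thread through create_thread as shown):

```python
# Sketch of the thread_context shape used by the Azure DevOps CommentThread above.
# Path and line numbers are hypothetical.
def build_thread_context(path: str, start_line: int, end_line: int) -> dict:
    return {
        "filePath": path,
        "rightFileStart": {"line": start_line, "offset": 1},
        "rightFileEnd": {"line": end_line, "offset": 1},
    }

assert build_thread_context("src/app.py", 10, 12)["rightFileEnd"]["line"] == 12
```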
@@ -204,9 +216,9 @@ class AzureDevopsProvider(GitProvider):
def get_files(self):
files = []
for i in self.azure_devops_client.get_pull_request_commits(
project=self.workspace_slug,
repository_id=self.repo_slug,
pull_request_id=self.pr_num,
project=self.workspace_slug,
repository_id=self.repo_slug,
pull_request_id=self.pr_num,
):
changes_obj = self.azure_devops_client.get_changes(
project=self.workspace_slug,
@@ -220,7 +232,6 @@ class AzureDevopsProvider(GitProvider):
def get_diff_files(self) -> list[FilePatchInfo]:
try:
if self.diff_files:
return self.diff_files
@@ -231,18 +242,20 @@ class AzureDevopsProvider(GitProvider):
iterations = self.azure_devops_client.get_pull_request_iterations(
repository_id=self.repo_slug,
pull_request_id=self.pr_num,
project=self.workspace_slug
project=self.workspace_slug,
)
changes = None
if iterations:
iteration_id = iterations[-1].id # Get the last iteration (most recent changes)
iteration_id = iterations[
-1
].id # Get the last iteration (most recent changes)
# Get changes for the iteration
changes = self.azure_devops_client.get_pull_request_iteration_changes(
repository_id=self.repo_slug,
pull_request_id=self.pr_num,
iteration_id=iteration_id,
project=self.workspace_slug
project=self.workspace_slug,
)
diff_files = []
diffs = []
@@ -253,7 +266,9 @@ class AzureDevopsProvider(GitProvider):
path = item.get('path', None)
if path:
diffs.append(path)
diff_types[path] = change.additional_properties.get('changeType', 'Unknown')
diff_types[path] = change.additional_properties.get(
'changeType', 'Unknown'
)
# wrong implementation - gets all the files that were changed in any commit in the PR
# commits = self.azure_devops_client.get_pull_request_commits(
@@ -284,9 +299,13 @@ class AzureDevopsProvider(GitProvider):
diffs = filter_ignored(diffs_original, 'azure')
if diffs_original != diffs:
try:
get_logger().info(f"Filtered out [ignore] files for pull request:", extra=
{"files": diffs_original, # diffs is just a list of names
"filtered_files": diffs})
get_logger().info(
f"Filtered out [ignore] files for pull request:",
extra={
"files": diffs_original, # diffs is just a list of names
"filtered_files": diffs,
},
)
except Exception:
pass
@@ -311,7 +330,10 @@ class AzureDevopsProvider(GitProvider):
new_file_content_str = new_file_content_str.content
except Exception as error:
get_logger().error(f"Failed to retrieve new file content of {file} at version {version}", error=error)
get_logger().error(
f"Failed to retrieve new file content of {file} at version {version}",
error=error,
)
# get_logger().error(
# "Failed to retrieve new file content of %s at version %s. Error: %s",
# file,
@@ -325,7 +347,9 @@ class AzureDevopsProvider(GitProvider):
edit_type = EDIT_TYPE.ADDED
elif diff_types[file] == "delete":
edit_type = EDIT_TYPE.DELETED
elif "rename" in diff_types[file]: # diff_type can be `rename` | `edit, rename`
elif (
"rename" in diff_types[file]
): # diff_type can be `rename` | `edit, rename`
edit_type = EDIT_TYPE.RENAMED
version = GitVersionDescriptor(
@@ -345,17 +369,27 @@ class AzureDevopsProvider(GitProvider):
)
original_file_content_str = original_file_content_str.content
except Exception as error:
get_logger().error(f"Failed to retrieve original file content of {file} at version {version}", error=error)
get_logger().error(
f"Failed to retrieve original file content of {file} at version {version}",
error=error,
)
original_file_content_str = ""
patch = load_large_diff(
file, new_file_content_str, original_file_content_str, show_warning=False
file,
new_file_content_str,
original_file_content_str,
show_warning=False,
).rstrip()
# count number of lines added and removed
patch_lines = patch.splitlines(keepends=True)
num_plus_lines = len([line for line in patch_lines if line.startswith('+')])
num_minus_lines = len([line for line in patch_lines if line.startswith('-')])
num_plus_lines = len(
[line for line in patch_lines if line.startswith('+')]
)
num_minus_lines = len(
[line for line in patch_lines if line.startswith('-')]
)
diff_files.append(
FilePatchInfo(
@@ -376,27 +410,35 @@ class AzureDevopsProvider(GitProvider):
get_logger().exception(f"Failed to get diff files, error: {e}")
return []
def publish_comment(self, pr_comment: str, is_temporary: bool = False, thread_context=None):
def publish_comment(
self, pr_comment: str, is_temporary: bool = False, thread_context=None
):
if is_temporary and not get_settings().config.publish_output_progress:
get_logger().debug(f"Skipping publish_comment for temporary comment: {pr_comment}")
get_logger().debug(
f"Skipping publish_comment for temporary comment: {pr_comment}"
)
return None
comment = Comment(content=pr_comment)
thread = CommentThread(comments=[comment], thread_context=thread_context, status=5)
thread = CommentThread(
comments=[comment], thread_context=thread_context, status=5
)
thread_response = self.azure_devops_client.create_thread(
comment_thread=thread,
project=self.workspace_slug,
repository_id=self.repo_slug,
pull_request_id=self.pr_num,
)
response = {"thread_id": thread_response.id, "comment_id": thread_response.comments[0].id}
response = {
"thread_id": thread_response.id,
"comment_id": thread_response.comments[0].id,
}
if is_temporary:
self.temp_comments.append(response)
return response
def publish_description(self, pr_title: str, pr_body: str):
if len(pr_body) > MAX_PR_DESCRIPTION_AZURE_LENGTH:
usage_guide_text='<details> <summary><strong>✨ Describe tool usage guide:</strong></summary><hr>'
usage_guide_text = '<details> <summary><strong>✨ Describe tool usage guide:</strong></summary><hr>'
ind = pr_body.find(usage_guide_text)
if ind != -1:
pr_body = pr_body[:ind]
@@ -409,7 +451,10 @@ class AzureDevopsProvider(GitProvider):
if len(pr_body) > MAX_PR_DESCRIPTION_AZURE_LENGTH:
trunction_message = " ... (description truncated due to length limit)"
pr_body = pr_body[:MAX_PR_DESCRIPTION_AZURE_LENGTH - len(trunction_message)] + trunction_message
pr_body = (
pr_body[: MAX_PR_DESCRIPTION_AZURE_LENGTH - len(trunction_message)]
+ trunction_message
)
get_logger().warning("PR description was truncated due to length limit")
try:
updated_pr = GitPullRequest()
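The truncation applied above is plain string arithmetic; a standalone sketch with a hypothetical 120-character limit in place of MAX_PR_DESCRIPTION_AZURE_LENGTH:

```python
# Hypothetical limit; the real code uses MAX_PR_DESCRIPTION_AZURE_LENGTH (4000 - 1).
LIMIT = 120
truncation_message = " ... (description truncated due to length limit)"
pr_body = "x" * 500
if len(pr_body) > LIMIT:
    pr_body = pr_body[: LIMIT - len(truncation_message)] + truncation_message
assert len(pr_body) == LIMIT
```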
@@ -433,50 +478,79 @@ class AzureDevopsProvider(GitProvider):
except Exception as e:
get_logger().exception(f"Failed to remove temp comments, error: {e}")
def publish_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str, original_suggestion=None):
self.publish_inline_comments([self.create_inline_comment(body, relevant_file, relevant_line_in_file)])
def publish_inline_comment(
self,
body: str,
relevant_file: str,
relevant_line_in_file: str,
original_suggestion=None,
):
self.publish_inline_comments(
[self.create_inline_comment(body, relevant_file, relevant_line_in_file)]
)
def create_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str,
absolute_position: int = None):
position, absolute_position = find_line_number_of_relevant_line_in_file(self.get_diff_files(),
relevant_file.strip('`'),
relevant_line_in_file,
absolute_position)
def create_inline_comment(
self,
body: str,
relevant_file: str,
relevant_line_in_file: str,
absolute_position: int = None,
):
position, absolute_position = find_line_number_of_relevant_line_in_file(
self.get_diff_files(),
relevant_file.strip('`'),
relevant_line_in_file,
absolute_position,
)
if position == -1:
if get_settings().config.verbosity_level >= 2:
get_logger().info(f"Could not find position for {relevant_file} {relevant_line_in_file}")
get_logger().info(
f"Could not find position for {relevant_file} {relevant_line_in_file}"
)
subject_type = "FILE"
else:
subject_type = "LINE"
path = relevant_file.strip()
return dict(body=body, path=path, position=position, absolute_position=absolute_position) if subject_type == "LINE" else {}
return (
dict(
body=body,
path=path,
position=position,
absolute_position=absolute_position,
)
if subject_type == "LINE"
else {}
)
def publish_inline_comments(self, comments: list[dict], disable_fallback: bool = False):
overall_success = True
for comment in comments:
try:
self.publish_comment(comment["body"],
thread_context={
"filePath": comment["path"],
"rightFileStart": {
"line": comment["absolute_position"],
"offset": comment["position"],
},
"rightFileEnd": {
"line": comment["absolute_position"],
"offset": comment["position"],
},
})
if get_settings().config.verbosity_level >= 2:
get_logger().info(
f"Published code suggestion on {self.pr_num} at {comment['path']}"
)
except Exception as e:
if get_settings().config.verbosity_level >= 2:
get_logger().error(f"Failed to publish code suggestion, error: {e}")
overall_success = False
return overall_success
def publish_inline_comments(
self, comments: list[dict], disable_fallback: bool = False
):
overall_success = True
for comment in comments:
try:
self.publish_comment(
comment["body"],
thread_context={
"filePath": comment["path"],
"rightFileStart": {
"line": comment["absolute_position"],
"offset": comment["position"],
},
"rightFileEnd": {
"line": comment["absolute_position"],
"offset": comment["position"],
},
},
)
if get_settings().config.verbosity_level >= 2:
get_logger().info(
f"Published code suggestion on {self.pr_num} at {comment['path']}"
)
except Exception as e:
if get_settings().config.verbosity_level >= 2:
get_logger().error(f"Failed to publish code suggestion, error: {e}")
overall_success = False
return overall_success
def get_title(self):
return self.pr.title
@@ -521,7 +595,11 @@ class AzureDevopsProvider(GitProvider):
return 0
def get_issue_comments(self):
threads = self.azure_devops_client.get_threads(repository_id=self.repo_slug, pull_request_id=self.pr_num, project=self.workspace_slug)
threads = self.azure_devops_client.get_threads(
repository_id=self.repo_slug,
pull_request_id=self.pr_num,
project=self.workspace_slug,
)
threads.reverse()
comment_list = []
for thread in threads:
@@ -532,7 +610,9 @@ class AzureDevopsProvider(GitProvider):
comment_list.append(comment)
return comment_list
def add_eyes_reaction(self, issue_comment_id: int, disable_eyes: bool = False) -> Optional[int]:
def add_eyes_reaction(
self, issue_comment_id: int, disable_eyes: bool = False
) -> Optional[int]:
return True
def remove_reaction(self, issue_comment_id: int, reaction_id: int) -> bool:
@@ -547,16 +627,22 @@ class AzureDevopsProvider(GitProvider):
raise ValueError(
"The provided URL does not appear to be a Azure DevOps PR URL"
)
if len(path_parts) == 6: # "https://dev.azure.com/organization/project/_git/repo/pullrequest/1"
if (
len(path_parts) == 6
): # "https://dev.azure.com/organization/project/_git/repo/pullrequest/1"
workspace_slug = path_parts[1]
repo_slug = path_parts[3]
pr_number = int(path_parts[5])
elif len(path_parts) == 5: # 'https://organization.visualstudio.com/project/_git/repo/pullrequest/1'
elif (
len(path_parts) == 5
): # 'https://organization.visualstudio.com/project/_git/repo/pullrequest/1'
workspace_slug = path_parts[0]
repo_slug = path_parts[2]
pr_number = int(path_parts[4])
else:
raise ValueError("The provided URL does not appear to be a Azure DevOps PR URL")
raise ValueError(
"The provided URL does not appear to be a Azure DevOps PR URL"
)
return workspace_slug, repo_slug, pr_number
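A standalone sketch of the two URL shapes handled above, using hypothetical organization, project and repo names; the real method builds path_parts before the length checks shown in the hunk:

```python
from urllib.parse import urlparse

def parse_ado_pr_url(pr_url: str):
    # Mirrors the path_parts index logic in the hunk above.
    path_parts = urlparse(pr_url).path.strip("/").split("/")
    if len(path_parts) == 6:  # dev.azure.com/<org>/<project>/_git/<repo>/pullrequest/<num>
        return path_parts[1], path_parts[3], int(path_parts[5])
    if len(path_parts) == 5:  # <org>.visualstudio.com/<project>/_git/<repo>/pullrequest/<num>
        return path_parts[0], path_parts[2], int(path_parts[4])
    raise ValueError("Not a recognized Azure DevOps PR URL (sketch)")

assert parse_ado_pr_url(
    "https://dev.azure.com/organization/project/_git/repo/pullrequest/1"
) == ("project", "repo", 1)
```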
@@ -575,12 +661,16 @@ class AzureDevopsProvider(GitProvider):
# try to use azure default credentials
# see https://learn.microsoft.com/en-us/python/api/overview/azure/identity-readme?view=azure-python
# for usage and env var configuration of user-assigned managed identity, local machine auth etc.
get_logger().info("No PAT found in settings, trying to use Azure Default Credentials.")
get_logger().info(
"No PAT found in settings, trying to use Azure Default Credentials."
)
credentials = DefaultAzureCredential()
accessToken = credentials.get_token(ADO_APP_CLIENT_DEFAULT_ID)
auth_token = accessToken.token
except Exception as e:
get_logger().error(f"No PAT found in settings, and Azure Default Authentication failed, error: {e}")
get_logger().error(
f"No PAT found in settings, and Azure Default Authentication failed, error: {e}"
)
raise
credentials = BasicAuthentication("", auth_token)
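The PAT-less fallback above can be exercised on its own; this sketch assumes the azure-identity and msrest packages plus a working local credential source (for example `az login` or a managed identity):

```python
# Isolated sketch of the DefaultAzureCredential fallback shown above.
from azure.identity import DefaultAzureCredential
from msrest.authentication import BasicAuthentication

ADO_APP_CLIENT_DEFAULT_ID = "499b84ac-1321-427f-aa17-267ca6975798/.default"  # same constant as defined earlier in this file

credentials = DefaultAzureCredential()
auth_token = credentials.get_token(ADO_APP_CLIENT_DEFAULT_ID).token
basic_auth = BasicAuthentication("", auth_token)  # subsequently handed to the azure-devops Connection
```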

View File

@@ -52,13 +52,19 @@ class BitbucketProvider(GitProvider):
self.git_files = None
if pr_url:
self.set_pr(pr_url)
self.bitbucket_comment_api_url = self.pr._BitbucketBase__data["links"]["comments"]["href"]
self.bitbucket_pull_request_api_url = self.pr._BitbucketBase__data["links"]['self']['href']
self.bitbucket_comment_api_url = self.pr._BitbucketBase__data["links"][
"comments"
]["href"]
self.bitbucket_pull_request_api_url = self.pr._BitbucketBase__data["links"][
'self'
]['href']
def get_repo_settings(self):
try:
url = (f"https://api.bitbucket.org/2.0/repositories/{self.workspace_slug}/{self.repo_slug}/src/"
f"{self.pr.destination_branch}/.pr_agent.toml")
url = (
f"https://api.bitbucket.org/2.0/repositories/{self.workspace_slug}/{self.repo_slug}/src/"
f"{self.pr.destination_branch}/.pr_agent.toml"
)
response = requests.request("GET", url, headers=self.headers)
if response.status_code == 404: # not found
return ""
@@ -74,20 +80,27 @@ class BitbucketProvider(GitProvider):
post_parameters_list = []
for suggestion in code_suggestions:
body = suggestion["body"]
original_suggestion = suggestion.get('original_suggestion', None) # needed for diff code
original_suggestion = suggestion.get(
'original_suggestion', None
) # needed for diff code
if original_suggestion:
try:
existing_code = original_suggestion['existing_code'].rstrip() + "\n"
improved_code = original_suggestion['improved_code'].rstrip() + "\n"
diff = difflib.unified_diff(existing_code.split('\n'),
improved_code.split('\n'), n=999)
diff = difflib.unified_diff(
existing_code.split('\n'), improved_code.split('\n'), n=999
)
patch_orig = "\n".join(diff)
patch = "\n".join(patch_orig.splitlines()[5:]).strip('\n')
diff_code = f"\n\n```diff\n{patch.rstrip()}\n```"
# replace ```suggestion ... ``` with diff_code, using regex:
body = re.sub(r'```suggestion.*?```', diff_code, body, flags=re.DOTALL)
body = re.sub(
r'```suggestion.*?```', diff_code, body, flags=re.DOTALL
)
except Exception as e:
get_logger().exception(f"Bitbucket failed to get diff code for publishing, error: {e}")
get_logger().exception(
f"Bitbucket failed to get diff code for publishing, error: {e}"
)
continue
relevant_file = suggestion["relevant_file"]
@@ -129,15 +142,22 @@ class BitbucketProvider(GitProvider):
self.publish_inline_comments(post_parameters_list)
return True
except Exception as e:
get_logger().error(f"Bitbucket failed to publish code suggestion, error: {e}")
get_logger().error(
f"Bitbucket failed to publish code suggestion, error: {e}"
)
return False
def publish_file_comments(self, file_comments: list) -> bool:
pass
def is_supported(self, capability: str) -> bool:
if capability in ['get_issue_comments', 'publish_inline_comments', 'get_labels', 'gfm_markdown',
'publish_file_comments']:
if capability in [
'get_issue_comments',
'publish_inline_comments',
'get_labels',
'gfm_markdown',
'publish_file_comments',
]:
return False
return True
@@ -169,12 +189,14 @@ class BitbucketProvider(GitProvider):
names_original = [d.new.path for d in diffs_original]
names_kept = [d.new.path for d in diffs]
names_filtered = list(set(names_original) - set(names_kept))
get_logger().info(f"Filtered out [ignore] files for PR", extra={
'original_files': names_original,
'names_kept': names_kept,
'names_filtered': names_filtered
})
get_logger().info(
f"Filtered out [ignore] files for PR",
extra={
'original_files': names_original,
'names_kept': names_kept,
'names_filtered': names_filtered,
},
)
except Exception as e:
pass
@@ -189,20 +211,32 @@ class BitbucketProvider(GitProvider):
for encoding in encodings_to_try:
try:
pr_patches = self.pr.diff(encoding=encoding)
get_logger().info(f"Successfully decoded PR patch with encoding {encoding}")
get_logger().info(
f"Successfully decoded PR patch with encoding {encoding}"
)
break
except UnicodeDecodeError:
continue
if pr_patches is None:
raise ValueError(f"Failed to decode PR patch with encodings {encodings_to_try}")
raise ValueError(
f"Failed to decode PR patch with encodings {encodings_to_try}"
)
diff_split = ["diff --git" + x for x in pr_patches.split("diff --git") if x.strip()]
diff_split = [
"diff --git" + x for x in pr_patches.split("diff --git") if x.strip()
]
# filter all elements of 'diff_split' that are of indices in 'diffs_original' that are not in 'diffs'
if len(diff_split) > len(diffs) and len(diffs_original) == len(diff_split):
diff_split = [diff_split[i] for i in range(len(diff_split)) if diffs_original[i] in diffs]
diff_split = [
diff_split[i]
for i in range(len(diff_split))
if diffs_original[i] in diffs
]
if len(diff_split) != len(diffs):
get_logger().error(f"Error - failed to split the diff into {len(diffs)} parts")
get_logger().error(
f"Error - failed to split the diff into {len(diffs)} parts"
)
return []
# bitbucket diff has a header for each file, we need to remove it:
# "diff --git filename
@@ -213,22 +247,34 @@ class BitbucketProvider(GitProvider):
# @@ -... @@"
for i, _ in enumerate(diff_split):
diff_split_lines = diff_split[i].splitlines()
if (len(diff_split_lines) >= 6) and \
((diff_split_lines[2].startswith("---") and
diff_split_lines[3].startswith("+++") and
diff_split_lines[4].startswith("@@")) or
(diff_split_lines[3].startswith("---") and # new or deleted file
diff_split_lines[4].startswith("+++") and
diff_split_lines[5].startswith("@@"))):
if (len(diff_split_lines) >= 6) and (
(
diff_split_lines[2].startswith("---")
and diff_split_lines[3].startswith("+++")
and diff_split_lines[4].startswith("@@")
)
or (
diff_split_lines[3].startswith("---")
and diff_split_lines[4].startswith("+++") # new or deleted file
and diff_split_lines[5].startswith("@@")
)
):
diff_split[i] = "\n".join(diff_split_lines[4:])
else:
if diffs[i].data.get('lines_added', 0) == 0 and diffs[i].data.get('lines_removed', 0) == 0:
if (
diffs[i].data.get('lines_added', 0) == 0
and diffs[i].data.get('lines_removed', 0) == 0
):
diff_split[i] = ""
elif len(diff_split_lines) <= 3:
diff_split[i] = ""
get_logger().info(f"Disregarding empty diff for file {_gef_filename(diffs[i])}")
get_logger().info(
f"Disregarding empty diff for file {_gef_filename(diffs[i])}"
)
else:
get_logger().warning(f"Bitbucket failed to get diff for file {_gef_filename(diffs[i])}")
get_logger().warning(
f"Bitbucket failed to get diff for file {_gef_filename(diffs[i])}"
)
diff_split[i] = ""
invalid_files_names = []
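The header stripping described in the comment above is easier to see on a tiny, self-contained example; the patch text here is hypothetical, and the real code also handles renamed, deleted and empty files:

```python
# Hypothetical one-file Bitbucket patch; only the split-and-strip idea from above is shown.
sample = (
    "diff --git a/foo.py b/foo.py\n"
    "index 1111111..2222222 100644\n"
    "--- a/foo.py\n"
    "+++ b/foo.py\n"
    "@@ -1,2 +1,2 @@\n"
    "-old\n"
    "+new\n"
)

diff_split = ["diff --git" + x for x in sample.split("diff --git") if x.strip()]
for i, part in enumerate(diff_split):
    lines = part.splitlines()
    # same shape check as above: ---, +++ and @@ sit on lines 2-4 of the per-file header
    if len(lines) >= 6 and lines[2].startswith("---") and lines[3].startswith("+++") and lines[4].startswith("@@"):
        diff_split[i] = "\n".join(lines[4:])

assert diff_split[0].splitlines()[0].startswith("@@ -1,2 +1,2 @@")
```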
@@ -246,24 +292,32 @@ class BitbucketProvider(GitProvider):
if get_settings().get("bitbucket_app.avoid_full_files", False):
original_file_content_str = ""
new_file_content_str = ""
elif counter_valid < MAX_FILES_ALLOWED_FULL // 2: # factor 2 because bitbucket has limited API calls
elif (
counter_valid < MAX_FILES_ALLOWED_FULL // 2
): # factor 2 because bitbucket has limited API calls
if diff.old.get_data("links"):
original_file_content_str = self._get_pr_file_content(
diff.old.get_data("links")['self']['href'])
diff.old.get_data("links")['self']['href']
)
else:
original_file_content_str = ""
if diff.new.get_data("links"):
new_file_content_str = self._get_pr_file_content(diff.new.get_data("links")['self']['href'])
new_file_content_str = self._get_pr_file_content(
diff.new.get_data("links")['self']['href']
)
else:
new_file_content_str = ""
else:
if counter_valid == MAX_FILES_ALLOWED_FULL // 2:
get_logger().info(
f"Bitbucket too many files in PR, will avoid loading full content for rest of files")
f"Bitbucket too many files in PR, will avoid loading full content for rest of files"
)
original_file_content_str = ""
new_file_content_str = ""
except Exception as e:
get_logger().exception(f"Error - bitbucket failed to get file content, error: {e}")
get_logger().exception(
f"Error - bitbucket failed to get file content, error: {e}"
)
original_file_content_str = ""
new_file_content_str = ""
@@ -285,7 +339,9 @@ class BitbucketProvider(GitProvider):
diff_files.append(file_patch_canonic_structure)
if invalid_files_names:
get_logger().info(f"Disregarding files with invalid extensions:\n{invalid_files_names}")
get_logger().info(
f"Disregarding files with invalid extensions:\n{invalid_files_names}"
)
self.diff_files = diff_files
return diff_files
@@ -296,11 +352,14 @@ class BitbucketProvider(GitProvider):
def get_comment_url(self, comment):
return comment.data['links']['html']['href']
def publish_persistent_comment(self, pr_comment: str,
initial_header: str,
update_header: bool = True,
name='review',
final_update_message=True):
def publish_persistent_comment(
self,
pr_comment: str,
initial_header: str,
update_header: bool = True,
name='review',
final_update_message=True,
):
try:
for comment in self.pr.comments():
body = comment.raw
@@ -309,15 +368,20 @@ class BitbucketProvider(GitProvider):
comment_url = self.get_comment_url(comment)
if update_header:
updated_header = f"{initial_header}\n\n#### ({name.capitalize()} updated until commit {latest_commit_url})\n"
pr_comment_updated = pr_comment.replace(initial_header, updated_header)
pr_comment_updated = pr_comment.replace(
initial_header, updated_header
)
else:
pr_comment_updated = pr_comment
get_logger().info(f"Persistent mode - updating comment {comment_url} to latest {name} message")
get_logger().info(
f"Persistent mode - updating comment {comment_url} to latest {name} message"
)
d = {"content": {"raw": pr_comment_updated}}
response = comment._update_data(comment.put(None, data=d))
if final_update_message:
self.publish_comment(
f"**[Persistent {name}]({comment_url})** updated to latest commit {latest_commit_url}")
f"**[Persistent {name}]({comment_url})** updated to latest commit {latest_commit_url}"
)
return
except Exception as e:
get_logger().exception(f"Failed to update persistent review, error: {e}")
@@ -326,7 +390,9 @@ class BitbucketProvider(GitProvider):
def publish_comment(self, pr_comment: str, is_temporary: bool = False):
if is_temporary and not get_settings().config.publish_output_progress:
get_logger().debug(f"Skipping publish_comment for temporary comment: {pr_comment}")
get_logger().debug(
f"Skipping publish_comment for temporary comment: {pr_comment}"
)
return None
pr_comment = self.limit_output_characters(pr_comment, self.max_comment_length)
comment = self.pr.comment(pr_comment)
@@ -355,39 +421,58 @@ class BitbucketProvider(GitProvider):
get_logger().exception(f"Failed to remove comment, error: {e}")
# function to create_inline_comment
def create_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str,
absolute_position: int = None):
def create_inline_comment(
self,
body: str,
relevant_file: str,
relevant_line_in_file: str,
absolute_position: int = None,
):
body = self.limit_output_characters(body, self.max_comment_length)
position, absolute_position = find_line_number_of_relevant_line_in_file(self.get_diff_files(),
relevant_file.strip('`'),
relevant_line_in_file,
absolute_position)
position, absolute_position = find_line_number_of_relevant_line_in_file(
self.get_diff_files(),
relevant_file.strip('`'),
relevant_line_in_file,
absolute_position,
)
if position == -1:
if get_settings().config.verbosity_level >= 2:
get_logger().info(f"Could not find position for {relevant_file} {relevant_line_in_file}")
get_logger().info(
f"Could not find position for {relevant_file} {relevant_line_in_file}"
)
subject_type = "FILE"
else:
subject_type = "LINE"
path = relevant_file.strip()
return dict(body=body, path=path, position=absolute_position) if subject_type == "LINE" else {}
return (
dict(body=body, path=path, position=absolute_position)
if subject_type == "LINE"
else {}
)
def publish_inline_comment(self, comment: str, from_line: int, file: str, original_suggestion=None):
def publish_inline_comment(
self, comment: str, from_line: int, file: str, original_suggestion=None
):
comment = self.limit_output_characters(comment, self.max_comment_length)
payload = json.dumps({
"content": {
"raw": comment,
},
"inline": {
"to": from_line,
"path": file
},
})
payload = json.dumps(
{
"content": {
"raw": comment,
},
"inline": {"to": from_line, "path": file},
}
)
response = requests.request(
"POST", self.bitbucket_comment_api_url, data=payload, headers=self.headers
)
return response
def get_line_link(self, relevant_file: str, relevant_line_start: int, relevant_line_end: int = None) -> str:
def get_line_link(
self,
relevant_file: str,
relevant_line_start: int,
relevant_line_end: int = None,
) -> str:
if relevant_line_start == -1:
link = f"{self.pr_url}/#L{relevant_file}"
else:
@@ -402,8 +487,9 @@ class BitbucketProvider(GitProvider):
return ""
diff_files = self.get_diff_files()
position, absolute_position = find_line_number_of_relevant_line_in_file \
(diff_files, relevant_file, relevant_line_str)
position, absolute_position = find_line_number_of_relevant_line_in_file(
diff_files, relevant_file, relevant_line_str
)
if absolute_position != -1 and self.pr_url:
link = f"{self.pr_url}/#L{relevant_file}T{absolute_position}"
@@ -417,12 +503,18 @@ class BitbucketProvider(GitProvider):
def publish_inline_comments(self, comments: list[dict]):
for comment in comments:
if 'position' in comment:
self.publish_inline_comment(comment['body'], comment['position'], comment['path'])
self.publish_inline_comment(
comment['body'], comment['position'], comment['path']
)
elif 'start_line' in comment: # multi-line comment
# note that bitbucket does not seem to support range - only a comment on a single line - https://community.developer.atlassian.com/t/api-post-endpoint-for-inline-pull-request-comments/60452
self.publish_inline_comment(comment['body'], comment['start_line'], comment['path'])
self.publish_inline_comment(
comment['body'], comment['start_line'], comment['path']
)
elif 'line' in comment: # single-line comment
self.publish_inline_comment(comment['body'], comment['line'], comment['path'])
self.publish_inline_comment(
comment['body'], comment['line'], comment['path']
)
else:
get_logger().error(f"Could not publish inline comment {comment}")
@@ -450,7 +542,9 @@ class BitbucketProvider(GitProvider):
"Bitbucket provider does not support issue comments yet"
)
def add_eyes_reaction(self, issue_comment_id: int, disable_eyes: bool = False) -> Optional[int]:
def add_eyes_reaction(
self, issue_comment_id: int, disable_eyes: bool = False
) -> Optional[int]:
return True
def remove_reaction(self, issue_comment_id: int, reaction_id: int) -> bool:
@@ -495,8 +589,10 @@ class BitbucketProvider(GitProvider):
branch = self.pr.data["source"]["commit"]["hash"]
elif branch == self.pr.destination_branch:
branch = self.pr.data["destination"]["commit"]["hash"]
url = (f"https://api.bitbucket.org/2.0/repositories/{self.workspace_slug}/{self.repo_slug}/src/"
f"{branch}/{file_path}")
url = (
f"https://api.bitbucket.org/2.0/repositories/{self.workspace_slug}/{self.repo_slug}/src/"
f"{branch}/{file_path}"
)
response = requests.request("GET", url, headers=self.headers)
if response.status_code == 404: # not found
return ""
@@ -505,23 +601,28 @@ class BitbucketProvider(GitProvider):
except Exception:
return ""
def create_or_update_pr_file(self, file_path: str, branch: str, contents="", message="") -> None:
url = (f"https://api.bitbucket.org/2.0/repositories/{self.workspace_slug}/{self.repo_slug}/src/")
def create_or_update_pr_file(
self, file_path: str, branch: str, contents="", message=""
) -> None:
url = f"https://api.bitbucket.org/2.0/repositories/{self.workspace_slug}/{self.repo_slug}/src/"
if not message:
if contents:
message = f"Update {file_path}"
else:
message = f"Create {file_path}"
files = {file_path: contents}
data = {
"message": message,
"branch": branch
}
headers = {'Authorization': self.headers['Authorization']} if 'Authorization' in self.headers else {}
data = {"message": message, "branch": branch}
headers = (
{'Authorization': self.headers['Authorization']}
if 'Authorization' in self.headers
else {}
)
try:
requests.request("POST", url, headers=headers, data=data, files=files)
except Exception:
get_logger().exception(f"Failed to create empty file {file_path} in branch {branch}")
get_logger().exception(
f"Failed to create empty file {file_path} in branch {branch}"
)
def _get_pr_file_content(self, remote_link: str):
try:
@@ -538,16 +639,19 @@ class BitbucketProvider(GitProvider):
# bitbucket does not support labels
def publish_description(self, pr_title: str, description: str):
payload = json.dumps({
"description": description,
"title": pr_title
})
payload = json.dumps({"description": description, "title": pr_title})
response = requests.request("PUT", self.bitbucket_pull_request_api_url, headers=self.headers, data=payload)
response = requests.request(
"PUT",
self.bitbucket_pull_request_api_url,
headers=self.headers,
data=payload,
)
try:
if response.status_code != 200:
get_logger().info(f"Failed to update description, error code: {response.status_code}")
get_logger().info(
f"Failed to update description, error code: {response.status_code}"
)
except:
pass
return response

View File

@@ -11,8 +11,7 @@ from requests.exceptions import HTTPError
from ..algo.git_patch_processing import decode_if_bytes
from ..algo.language_handler import is_valid_file
from ..algo.types import EDIT_TYPE, FilePatchInfo
from ..algo.utils import (find_line_number_of_relevant_line_in_file,
load_large_diff)
from ..algo.utils import find_line_number_of_relevant_line_in_file, load_large_diff
from ..config_loader import get_settings
from ..log import get_logger
from .git_provider import GitProvider
@@ -20,8 +19,10 @@ from .git_provider import GitProvider
class BitbucketServerProvider(GitProvider):
def __init__(
self, pr_url: Optional[str] = None, incremental: Optional[bool] = False,
bitbucket_client: Optional[Bitbucket] = None,
self,
pr_url: Optional[str] = None,
incremental: Optional[bool] = False,
bitbucket_client: Optional[Bitbucket] = None,
):
self.bitbucket_server_url = None
self.workspace_slug = None
@@ -36,11 +37,16 @@ class BitbucketServerProvider(GitProvider):
self.bitbucket_pull_request_api_url = pr_url
self.bitbucket_server_url = self._parse_bitbucket_server(url=pr_url)
self.bitbucket_client = bitbucket_client or Bitbucket(url=self.bitbucket_server_url,
token=get_settings().get("BITBUCKET_SERVER.BEARER_TOKEN",
None))
self.bitbucket_client = bitbucket_client or Bitbucket(
url=self.bitbucket_server_url,
token=get_settings().get("BITBUCKET_SERVER.BEARER_TOKEN", None),
)
try:
self.bitbucket_api_version = parse_version(self.bitbucket_client.get("rest/api/1.0/application-properties").get('version'))
self.bitbucket_api_version = parse_version(
self.bitbucket_client.get("rest/api/1.0/application-properties").get(
'version'
)
)
except Exception:
self.bitbucket_api_version = None
@@ -49,7 +55,12 @@ class BitbucketServerProvider(GitProvider):
def get_repo_settings(self):
try:
content = self.bitbucket_client.get_content_of_file(self.workspace_slug, self.repo_slug, ".pr_agent.toml", self.get_pr_branch())
content = self.bitbucket_client.get_content_of_file(
self.workspace_slug,
self.repo_slug,
".pr_agent.toml",
self.get_pr_branch(),
)
return content
except Exception as e:
@@ -70,20 +81,27 @@ class BitbucketServerProvider(GitProvider):
post_parameters_list = []
for suggestion in code_suggestions:
body = suggestion["body"]
original_suggestion = suggestion.get('original_suggestion', None) # needed for diff code
original_suggestion = suggestion.get(
'original_suggestion', None
) # needed for diff code
if original_suggestion:
try:
existing_code = original_suggestion['existing_code'].rstrip() + "\n"
improved_code = original_suggestion['improved_code'].rstrip() + "\n"
diff = difflib.unified_diff(existing_code.split('\n'),
improved_code.split('\n'), n=999)
diff = difflib.unified_diff(
existing_code.split('\n'), improved_code.split('\n'), n=999
)
patch_orig = "\n".join(diff)
patch = "\n".join(patch_orig.splitlines()[5:]).strip('\n')
diff_code = f"\n\n```diff\n{patch.rstrip()}\n```"
# replace ```suggestion ... ``` with diff_code, using regex:
body = re.sub(r'```suggestion.*?```', diff_code, body, flags=re.DOTALL)
body = re.sub(
r'```suggestion.*?```', diff_code, body, flags=re.DOTALL
)
except Exception as e:
get_logger().exception(f"Bitbucket failed to get diff code for publishing, error: {e}")
get_logger().exception(
f"Bitbucket failed to get diff code for publishing, error: {e}"
)
continue
relevant_file = suggestion["relevant_file"]
relevant_lines_start = suggestion["relevant_lines_start"]
@@ -134,7 +152,12 @@ class BitbucketServerProvider(GitProvider):
pass
def is_supported(self, capability: str) -> bool:
if capability in ['get_issue_comments', 'get_labels', 'gfm_markdown', 'publish_file_comments']:
if capability in [
'get_issue_comments',
'get_labels',
'gfm_markdown',
'publish_file_comments',
]:
return False
return True
@@ -145,23 +168,28 @@ class BitbucketServerProvider(GitProvider):
def get_file(self, path: str, commit_id: str):
file_content = ""
try:
file_content = self.bitbucket_client.get_content_of_file(self.workspace_slug,
self.repo_slug,
path,
commit_id)
file_content = self.bitbucket_client.get_content_of_file(
self.workspace_slug, self.repo_slug, path, commit_id
)
except HTTPError as e:
get_logger().debug(f"File {path} not found at commit id: {commit_id}")
return file_content
def get_files(self):
changes = self.bitbucket_client.get_pull_requests_changes(self.workspace_slug, self.repo_slug, self.pr_num)
changes = self.bitbucket_client.get_pull_requests_changes(
self.workspace_slug, self.repo_slug, self.pr_num
)
diffstat = [change["path"]['toString'] for change in changes]
return diffstat
#gets the best common ancestor: https://git-scm.com/docs/git-merge-base
# gets the best common ancestor: https://git-scm.com/docs/git-merge-base
@staticmethod
def get_best_common_ancestor(source_commits_list, destination_commits_list, guaranteed_common_ancestor) -> str:
destination_commit_hashes = {commit['id'] for commit in destination_commits_list} | {guaranteed_common_ancestor}
def get_best_common_ancestor(
source_commits_list, destination_commits_list, guaranteed_common_ancestor
) -> str:
destination_commit_hashes = {
commit['id'] for commit in destination_commits_list
} | {guaranteed_common_ancestor}
for commit in source_commits_list:
for parent_commit in commit['parents']:
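The inner-loop body sits outside this hunk; a plausible minimal reading of the ancestor search, with toy commit dicts shaped like the Bitbucket Server API responses, is:

```python
# Hedged sketch: walk source commits newest-first and return the first parent
# that is already known on the destination side (or the guaranteed ancestor).
def best_common_ancestor(source_commits, destination_commits, guaranteed_common_ancestor):
    destination_ids = {c["id"] for c in destination_commits} | {guaranteed_common_ancestor}
    for commit in source_commits:
        for parent in commit["parents"]:
            if parent["id"] in destination_ids:
                return parent["id"]
    return guaranteed_common_ancestor

src = [{"id": "s2", "parents": [{"id": "s1"}]}, {"id": "s1", "parents": [{"id": "m1"}]}]
dst = [{"id": "m2"}, {"id": "m1"}]
assert best_common_ancestor(src, dst, "m0") == "m1"
```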
@@ -177,37 +205,55 @@ class BitbucketServerProvider(GitProvider):
head_sha = self.pr.fromRef['latestCommit']
# if Bitbucket api version is >= 8.16 then use the merge-base api for 2-way diff calculation
if self.bitbucket_api_version is not None and self.bitbucket_api_version >= parse_version("8.16"):
if (
self.bitbucket_api_version is not None
and self.bitbucket_api_version >= parse_version("8.16")
):
try:
base_sha = self.bitbucket_client.get(self._get_merge_base())['id']
except Exception as e:
get_logger().error(f"Failed to get the best common ancestor for PR: {self.pr_url}, \nerror: {e}")
get_logger().error(
f"Failed to get the best common ancestor for PR: {self.pr_url}, \nerror: {e}"
)
raise e
else:
source_commits_list = list(self.bitbucket_client.get_pull_requests_commits(
self.workspace_slug,
self.repo_slug,
self.pr_num
))
source_commits_list = list(
self.bitbucket_client.get_pull_requests_commits(
self.workspace_slug, self.repo_slug, self.pr_num
)
)
# if Bitbucket api version is None or < 7.0 then do a simple diff with a guaranteed common ancestor
base_sha = source_commits_list[-1]['parents'][0]['id']
# if Bitbucket api version is 7.0-8.15 then use 2-way diff functionality for the base_sha
if self.bitbucket_api_version is not None and self.bitbucket_api_version >= parse_version("7.0"):
if (
self.bitbucket_api_version is not None
and self.bitbucket_api_version >= parse_version("7.0")
):
try:
destination_commits = list(
self.bitbucket_client.get_commits(self.workspace_slug, self.repo_slug, base_sha,
self.pr.toRef['latestCommit']))
base_sha = self.get_best_common_ancestor(source_commits_list, destination_commits, base_sha)
self.bitbucket_client.get_commits(
self.workspace_slug,
self.repo_slug,
base_sha,
self.pr.toRef['latestCommit'],
)
)
base_sha = self.get_best_common_ancestor(
source_commits_list, destination_commits, base_sha
)
except Exception as e:
get_logger().error(
f"Failed to get the commit list for calculating best common ancestor for PR: {self.pr_url}, \nerror: {e}")
f"Failed to get the commit list for calculating best common ancestor for PR: {self.pr_url}, \nerror: {e}"
)
raise e
diff_files = []
original_file_content_str = ""
new_file_content_str = ""
changes = self.bitbucket_client.get_pull_requests_changes(self.workspace_slug, self.repo_slug, self.pr_num)
changes = self.bitbucket_client.get_pull_requests_changes(
self.workspace_slug, self.repo_slug, self.pr_num
)
for change in changes:
file_path = change['path']['toString']
if not is_valid_file(file_path.split("/")[-1]):
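The version gates earlier in this hunk lean on packaging-style comparisons; assuming parse_version is packaging.version.parse (its import sits in the file header, outside this diff), the thresholds behave as follows:

```python
# Assumption: parse_version == packaging.version.parse.
from packaging.version import parse as parse_version

assert parse_version("8.16") >= parse_version("8.16")        # merge-base API path
assert parse_version("7.5") >= parse_version("7.0")          # 2-way diff path
assert not parse_version("6.10") >= parse_version("7.0")     # guaranteed-ancestor fallback
```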
@@ -224,17 +270,26 @@ class BitbucketServerProvider(GitProvider):
edit_type = EDIT_TYPE.DELETED
new_file_content_str = ""
original_file_content_str = self.get_file(file_path, base_sha)
original_file_content_str = decode_if_bytes(original_file_content_str)
original_file_content_str = decode_if_bytes(
original_file_content_str
)
case 'RENAME':
edit_type = EDIT_TYPE.RENAMED
case _:
edit_type = EDIT_TYPE.MODIFIED
original_file_content_str = self.get_file(file_path, base_sha)
original_file_content_str = decode_if_bytes(original_file_content_str)
original_file_content_str = decode_if_bytes(
original_file_content_str
)
new_file_content_str = self.get_file(file_path, head_sha)
new_file_content_str = decode_if_bytes(new_file_content_str)
patch = load_large_diff(file_path, new_file_content_str, original_file_content_str, show_warning=False)
patch = load_large_diff(
file_path,
new_file_content_str,
original_file_content_str,
show_warning=False,
)
diff_files.append(
FilePatchInfo(
@@ -251,7 +306,9 @@ class BitbucketServerProvider(GitProvider):
def publish_comment(self, pr_comment: str, is_temporary: bool = False):
if not is_temporary:
self.bitbucket_client.add_pull_request_comment(self.workspace_slug, self.repo_slug, self.pr_num, pr_comment)
self.bitbucket_client.add_pull_request_comment(
self.workspace_slug, self.repo_slug, self.pr_num, pr_comment
)
def remove_initial_comment(self):
try:
@@ -264,25 +321,37 @@ class BitbucketServerProvider(GitProvider):
pass
# function to create_inline_comment
def create_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str,
absolute_position: int = None):
def create_inline_comment(
self,
body: str,
relevant_file: str,
relevant_line_in_file: str,
absolute_position: int = None,
):
position, absolute_position = find_line_number_of_relevant_line_in_file(
self.get_diff_files(),
relevant_file.strip('`'),
relevant_line_in_file,
absolute_position
absolute_position,
)
if position == -1:
if get_settings().config.verbosity_level >= 2:
get_logger().info(f"Could not find position for {relevant_file} {relevant_line_in_file}")
get_logger().info(
f"Could not find position for {relevant_file} {relevant_line_in_file}"
)
subject_type = "FILE"
else:
subject_type = "LINE"
path = relevant_file.strip()
return dict(body=body, path=path, position=absolute_position) if subject_type == "LINE" else {}
return (
dict(body=body, path=path, position=absolute_position)
if subject_type == "LINE"
else {}
)
def publish_inline_comment(self, comment: str, from_line: int, file: str, original_suggestion=None):
def publish_inline_comment(
self, comment: str, from_line: int, file: str, original_suggestion=None
):
payload = {
"text": comment,
"severity": "NORMAL",
@@ -291,17 +360,24 @@ class BitbucketServerProvider(GitProvider):
"path": file,
"lineType": "ADDED",
"line": from_line,
"fileType": "TO"
}
"fileType": "TO",
},
}
try:
self.bitbucket_client.post(self._get_pr_comments_path(), data=payload)
except Exception as e:
get_logger().error(f"Failed to publish inline comment to '{file}' at line {from_line}, error: {e}")
get_logger().error(
f"Failed to publish inline comment to '{file}' at line {from_line}, error: {e}"
)
raise e
def get_line_link(self, relevant_file: str, relevant_line_start: int, relevant_line_end: int = None) -> str:
def get_line_link(
self,
relevant_file: str,
relevant_line_start: int,
relevant_line_end: int = None,
) -> str:
if relevant_line_start == -1:
link = f"{self.pr_url}/diff#{quote_plus(relevant_file)}"
else:
@@ -316,8 +392,9 @@ class BitbucketServerProvider(GitProvider):
return ""
diff_files = self.get_diff_files()
position, absolute_position = find_line_number_of_relevant_line_in_file \
(diff_files, relevant_file, relevant_line_str)
position, absolute_position = find_line_number_of_relevant_line_in_file(
diff_files, relevant_file, relevant_line_str
)
if absolute_position != -1:
if self.pr:
@@ -325,29 +402,41 @@ class BitbucketServerProvider(GitProvider):
return link
else:
if get_settings().config.verbosity_level >= 2:
get_logger().info(f"Failed adding line link to '{relevant_file}' since PR not set")
get_logger().info(
f"Failed adding line link to '{relevant_file}' since PR not set"
)
else:
if get_settings().config.verbosity_level >= 2:
get_logger().info(f"Failed adding line link to '{relevant_file}' since position not found")
get_logger().info(
f"Failed adding line link to '{relevant_file}' since position not found"
)
if absolute_position != -1 and self.pr_url:
link = f"{self.pr_url}/diff#{quote_plus(relevant_file)}?t={absolute_position}"
return link
except Exception as e:
if get_settings().config.verbosity_level >= 2:
get_logger().info(f"Failed adding line link to '{relevant_file}', error: {e}")
get_logger().info(
f"Failed adding line link to '{relevant_file}', error: {e}"
)
return ""
def publish_inline_comments(self, comments: list[dict]):
for comment in comments:
if 'position' in comment:
self.publish_inline_comment(comment['body'], comment['position'], comment['path'])
elif 'start_line' in comment: # multi-line comment
self.publish_inline_comment(
comment['body'], comment['position'], comment['path']
)
elif 'start_line' in comment: # multi-line comment
# note that bitbucket does not seem to support range - only a comment on a single line - https://community.developer.atlassian.com/t/api-post-endpoint-for-inline-pull-request-comments/60452
self.publish_inline_comment(comment['body'], comment['start_line'], comment['path'])
elif 'line' in comment: # single-line comment
self.publish_inline_comment(comment['body'], comment['line'], comment['path'])
self.publish_inline_comment(
comment['body'], comment['start_line'], comment['path']
)
elif 'line' in comment: # single-line comment
self.publish_inline_comment(
comment['body'], comment['line'], comment['path']
)
else:
get_logger().error(f"Could not publish inline comment: {comment}")
@@ -377,7 +466,9 @@ class BitbucketServerProvider(GitProvider):
"Bitbucket provider does not support issue comments yet"
)
def add_eyes_reaction(self, issue_comment_id: int, disable_eyes: bool = False) -> Optional[int]:
def add_eyes_reaction(
self, issue_comment_id: int, disable_eyes: bool = False
) -> Optional[int]:
return True
def remove_reaction(self, issue_comment_id: int, reaction_id: int) -> bool:
@@ -411,14 +502,20 @@ class BitbucketServerProvider(GitProvider):
users_index = -1
if projects_index == -1 and users_index == -1:
raise ValueError(f"The provided URL '{pr_url}' does not appear to be a Bitbucket PR URL")
raise ValueError(
f"The provided URL '{pr_url}' does not appear to be a Bitbucket PR URL"
)
if projects_index != -1:
path_parts = path_parts[projects_index:]
else:
path_parts = path_parts[users_index:]
if len(path_parts) < 6 or path_parts[2] != "repos" or path_parts[4] != "pull-requests":
if (
len(path_parts) < 6
or path_parts[2] != "repos"
or path_parts[4] != "pull-requests"
):
raise ValueError(
f"The provided URL '{pr_url}' does not appear to be a Bitbucket PR URL"
)
@@ -430,19 +527,24 @@ class BitbucketServerProvider(GitProvider):
try:
pr_number = int(path_parts[5])
except ValueError as e:
raise ValueError(f"Unable to convert PR number '{path_parts[5]}' to integer") from e
raise ValueError(
f"Unable to convert PR number '{path_parts[5]}' to integer"
) from e
return workspace_slug, repo_slug, pr_number
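A standalone sketch of the URL checks above, with a hypothetical host and project key; the slug assignments sit outside the hunks shown, so the indices here follow only the length and position checks:

```python
from urllib.parse import urlparse

def parse_bitbucket_server_pr_url(pr_url: str):
    # Mirrors the checks above: .../projects/<KEY>/repos/<slug>/pull-requests/<num>
    path_parts = urlparse(pr_url).path.strip("/").split("/")
    if "projects" in path_parts:
        path_parts = path_parts[path_parts.index("projects"):]
    elif "users" in path_parts:
        path_parts = path_parts[path_parts.index("users"):]
    else:
        raise ValueError(f"The provided URL '{pr_url}' does not appear to be a Bitbucket PR URL")
    if len(path_parts) < 6 or path_parts[2] != "repos" or path_parts[4] != "pull-requests":
        raise ValueError(f"The provided URL '{pr_url}' does not appear to be a Bitbucket PR URL")
    return path_parts[1], path_parts[3], int(path_parts[5])

assert parse_bitbucket_server_pr_url(
    "https://bitbucket.example.com/projects/PROJ/repos/my-repo/pull-requests/42"
) == ("PROJ", "my-repo", 42)
```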
def _get_repo(self):
if self.repo is None:
self.repo = self.bitbucket_client.get_repo(self.workspace_slug, self.repo_slug)
self.repo = self.bitbucket_client.get_repo(
self.workspace_slug, self.repo_slug
)
return self.repo
def _get_pr(self):
try:
pr = self.bitbucket_client.get_pull_request(self.workspace_slug, self.repo_slug,
pull_request_id=self.pr_num)
pr = self.bitbucket_client.get_pull_request(
self.workspace_slug, self.repo_slug, pull_request_id=self.pr_num
)
return type('new_dict', (object,), pr)
except Exception as e:
get_logger().error(f"Failed to get pull request, error: {e}")
@@ -460,10 +562,12 @@ class BitbucketServerProvider(GitProvider):
"version": self.pr.version,
"description": description,
"title": pr_title,
"reviewers": self.pr.reviewers # needs to be sent otherwise gets wiped
"reviewers": self.pr.reviewers, # needs to be sent otherwise gets wiped
}
try:
self.bitbucket_client.update_pull_request(self.workspace_slug, self.repo_slug, str(self.pr_num), payload)
self.bitbucket_client.update_pull_request(
self.workspace_slug, self.repo_slug, str(self.pr_num), payload
)
except Exception as e:
get_logger().error(f"Failed to update pull request, error: {e}")
raise e

View File

@@ -31,7 +31,9 @@ class CodeCommitPullRequestResponse:
self.targets = []
for target in json.get("pullRequestTargets", []):
self.targets.append(CodeCommitPullRequestResponse.CodeCommitPullRequestTarget(target))
self.targets.append(
CodeCommitPullRequestResponse.CodeCommitPullRequestTarget(target)
)
class CodeCommitPullRequestTarget:
"""
@@ -65,7 +67,9 @@ class CodeCommitClient:
except Exception as e:
raise ValueError(f"Failed to connect to AWS CodeCommit: {e}") from e
def get_differences(self, repo_name: int, destination_commit: str, source_commit: str):
def get_differences(
self, repo_name: int, destination_commit: str, source_commit: str
):
"""
Get the differences between two commits in CodeCommit.
@@ -96,17 +100,25 @@ class CodeCommitClient:
differences.extend(page.get("differences", []))
except botocore.exceptions.ClientError as e:
if e.response["Error"]["Code"] == 'RepositoryDoesNotExistException':
raise ValueError(f"CodeCommit cannot retrieve differences: Repository does not exist: {repo_name}") from e
raise ValueError(f"CodeCommit cannot retrieve differences for {source_commit}..{destination_commit}") from e
raise ValueError(
f"CodeCommit cannot retrieve differences: Repository does not exist: {repo_name}"
) from e
raise ValueError(
f"CodeCommit cannot retrieve differences for {source_commit}..{destination_commit}"
) from e
except Exception as e:
raise ValueError(f"CodeCommit cannot retrieve differences for {source_commit}..{destination_commit}") from e
raise ValueError(
f"CodeCommit cannot retrieve differences for {source_commit}..{destination_commit}"
) from e
output = []
for json in differences:
output.append(CodeCommitDifferencesResponse(json))
return output
def get_file(self, repo_name: str, file_path: str, sha_hash: str, optional: bool = False):
def get_file(
self, repo_name: str, file_path: str, sha_hash: str, optional: bool = False
):
"""
Retrieve a file from CodeCommit.
@@ -129,16 +141,24 @@ class CodeCommitClient:
self._connect_boto_client()
try:
response = self.boto_client.get_file(repositoryName=repo_name, commitSpecifier=sha_hash, filePath=file_path)
response = self.boto_client.get_file(
repositoryName=repo_name, commitSpecifier=sha_hash, filePath=file_path
)
except botocore.exceptions.ClientError as e:
if e.response["Error"]["Code"] == 'RepositoryDoesNotExistException':
raise ValueError(f"CodeCommit cannot retrieve PR: Repository does not exist: {repo_name}") from e
raise ValueError(
f"CodeCommit cannot retrieve PR: Repository does not exist: {repo_name}"
) from e
# if the file does not exist, but is flagged as optional, then return an empty string
if optional and e.response["Error"]["Code"] == 'FileDoesNotExistException':
return ""
raise ValueError(f"CodeCommit cannot retrieve file '{file_path}' from repository '{repo_name}'") from e
raise ValueError(
f"CodeCommit cannot retrieve file '{file_path}' from repository '{repo_name}'"
) from e
except Exception as e:
raise ValueError(f"CodeCommit cannot retrieve file '{file_path}' from repository '{repo_name}'") from e
raise ValueError(
f"CodeCommit cannot retrieve file '{file_path}' from repository '{repo_name}'"
) from e
if "fileContent" not in response:
raise ValueError(f"File content is empty for file: {file_path}")
@@ -166,10 +186,16 @@ class CodeCommitClient:
response = self.boto_client.get_pull_request(pullRequestId=str(pr_number))
except botocore.exceptions.ClientError as e:
if e.response["Error"]["Code"] == 'PullRequestDoesNotExistException':
raise ValueError(f"CodeCommit cannot retrieve PR: PR number does not exist: {pr_number}") from e
raise ValueError(
f"CodeCommit cannot retrieve PR: PR number does not exist: {pr_number}"
) from e
if e.response["Error"]["Code"] == 'RepositoryDoesNotExistException':
raise ValueError(f"CodeCommit cannot retrieve PR: Repository does not exist: {repo_name}") from e
raise ValueError(f"CodeCommit cannot retrieve PR: {pr_number}: boto client error") from e
raise ValueError(
f"CodeCommit cannot retrieve PR: Repository does not exist: {repo_name}"
) from e
raise ValueError(
f"CodeCommit cannot retrieve PR: {pr_number}: boto client error"
) from e
except Exception as e:
raise ValueError(f"CodeCommit cannot retrieve PR: {pr_number}") from e
@@ -200,22 +226,37 @@ class CodeCommitClient:
self._connect_boto_client()
try:
self.boto_client.update_pull_request_title(pullRequestId=str(pr_number), title=pr_title)
self.boto_client.update_pull_request_description(pullRequestId=str(pr_number), description=pr_body)
self.boto_client.update_pull_request_title(
pullRequestId=str(pr_number), title=pr_title
)
self.boto_client.update_pull_request_description(
pullRequestId=str(pr_number), description=pr_body
)
except botocore.exceptions.ClientError as e:
if e.response["Error"]["Code"] == 'PullRequestDoesNotExistException':
raise ValueError(f"PR number does not exist: {pr_number}") from e
if e.response["Error"]["Code"] == 'InvalidTitleException':
raise ValueError(f"Invalid title for PR number: {pr_number}") from e
if e.response["Error"]["Code"] == 'InvalidDescriptionException':
raise ValueError(f"Invalid description for PR number: {pr_number}") from e
raise ValueError(
f"Invalid description for PR number: {pr_number}"
) from e
if e.response["Error"]["Code"] == 'PullRequestAlreadyClosedException':
raise ValueError(f"PR is already closed: PR number: {pr_number}") from e
raise ValueError(f"Boto3 client error calling publish_description") from e
except Exception as e:
raise ValueError(f"Error calling publish_description") from e
def publish_comment(self, repo_name: str, pr_number: int, destination_commit: str, source_commit: str, comment: str, annotation_file: str = None, annotation_line: int = None):
def publish_comment(
self,
repo_name: str,
pr_number: int,
destination_commit: str,
source_commit: str,
comment: str,
annotation_file: str = None,
annotation_line: int = None,
):
"""
Publish a comment to a pull request
@@ -272,6 +313,8 @@ class CodeCommitClient:
raise ValueError(f"Repository does not exist: {repo_name}") from e
if e.response["Error"]["Code"] == 'PullRequestDoesNotExistException':
raise ValueError(f"PR number does not exist: {pr_number}") from e
raise ValueError(f"Boto3 client error calling post_comment_for_pull_request") from e
raise ValueError(
f"Boto3 client error calling post_comment_for_pull_request"
) from e
except Exception as e:
raise ValueError(f"Error calling post_comment_for_pull_request") from e

View File

@@ -55,7 +55,9 @@ class CodeCommitProvider(GitProvider):
This class implements the GitProvider interface for AWS CodeCommit repositories.
"""
def __init__(self, pr_url: Optional[str] = None, incremental: Optional[bool] = False):
def __init__(
self, pr_url: Optional[str] = None, incremental: Optional[bool] = False
):
self.codecommit_client = CodeCommitClient()
self.aws_client = None
self.repo_name = None
@@ -76,7 +78,7 @@ class CodeCommitProvider(GitProvider):
"create_inline_comment",
"publish_inline_comments",
"get_labels",
"gfm_markdown"
"gfm_markdown",
]:
return False
return True
@@ -91,13 +93,19 @@ class CodeCommitProvider(GitProvider):
return self.git_files
self.git_files = []
differences = self.codecommit_client.get_differences(self.repo_name, self.pr.destination_commit, self.pr.source_commit)
differences = self.codecommit_client.get_differences(
self.repo_name, self.pr.destination_commit, self.pr.source_commit
)
for item in differences:
self.git_files.append(CodeCommitFile(item.before_blob_path,
item.before_blob_id,
item.after_blob_path,
item.after_blob_id,
CodeCommitProvider._get_edit_type(item.change_type)))
self.git_files.append(
CodeCommitFile(
item.before_blob_path,
item.before_blob_id,
item.after_blob_path,
item.after_blob_id,
CodeCommitProvider._get_edit_type(item.change_type),
)
)
return self.git_files
def get_diff_files(self) -> list[FilePatchInfo]:
@@ -121,21 +129,28 @@ class CodeCommitProvider(GitProvider):
if diff_item.a_blob_id is not None:
patch_filename = diff_item.a_path
original_file_content_str = self.codecommit_client.get_file(
self.repo_name, diff_item.a_path, self.pr.destination_commit)
self.repo_name, diff_item.a_path, self.pr.destination_commit
)
if isinstance(original_file_content_str, (bytes, bytearray)):
original_file_content_str = original_file_content_str.decode("utf-8")
original_file_content_str = original_file_content_str.decode(
"utf-8"
)
else:
original_file_content_str = ""
if diff_item.b_blob_id is not None:
patch_filename = diff_item.b_path
new_file_content_str = self.codecommit_client.get_file(self.repo_name, diff_item.b_path, self.pr.source_commit)
new_file_content_str = self.codecommit_client.get_file(
self.repo_name, diff_item.b_path, self.pr.source_commit
)
if isinstance(new_file_content_str, (bytes, bytearray)):
new_file_content_str = new_file_content_str.decode("utf-8")
else:
new_file_content_str = ""
patch = load_large_diff(patch_filename, new_file_content_str, original_file_content_str)
patch = load_large_diff(
patch_filename, new_file_content_str, original_file_content_str
)
# Store the diffs as a list of FilePatchInfo objects
info = FilePatchInfo(
@@ -164,7 +179,9 @@ class CodeCommitProvider(GitProvider):
pr_body=CodeCommitProvider._add_additional_newlines(pr_body),
)
except Exception as e:
raise ValueError(f"CodeCommit Cannot publish description for PR: {self.pr_num}") from e
raise ValueError(
f"CodeCommit Cannot publish description for PR: {self.pr_num}"
) from e
def publish_comment(self, pr_comment: str, is_temporary: bool = False):
if is_temporary:
@@ -183,19 +200,28 @@ class CodeCommitProvider(GitProvider):
comment=pr_comment,
)
except Exception as e:
raise ValueError(f"CodeCommit Cannot publish comment for PR: {self.pr_num}") from e
raise ValueError(
f"CodeCommit Cannot publish comment for PR: {self.pr_num}"
) from e
def publish_code_suggestions(self, code_suggestions: list) -> bool:
counter = 1
for suggestion in code_suggestions:
# Verify that each suggestion has the required keys
if not all(key in suggestion for key in ["body", "relevant_file", "relevant_lines_start"]):
get_logger().warning(f"Skipping code suggestion #{counter}: Each suggestion must have 'body', 'relevant_file', 'relevant_lines_start' keys")
if not all(
key in suggestion
for key in ["body", "relevant_file", "relevant_lines_start"]
):
get_logger().warning(
f"Skipping code suggestion #{counter}: Each suggestion must have 'body', 'relevant_file', 'relevant_lines_start' keys"
)
continue
# Publish the code suggestion to CodeCommit
try:
get_logger().debug(f"Code Suggestion #{counter} in file: {suggestion['relevant_file']}: {suggestion['relevant_lines_start']}")
get_logger().debug(
f"Code Suggestion #{counter} in file: {suggestion['relevant_file']}: {suggestion['relevant_lines_start']}"
)
self.codecommit_client.publish_comment(
repo_name=self.repo_name,
pr_number=self.pr_num,
@@ -206,7 +232,9 @@ class CodeCommitProvider(GitProvider):
annotation_line=suggestion["relevant_lines_start"],
)
except Exception as e:
raise ValueError(f"CodeCommit Cannot publish code suggestions for PR: {self.pr_num}") from e
raise ValueError(
f"CodeCommit Cannot publish code suggestions for PR: {self.pr_num}"
) from e
counter += 1
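The required-key filter above is easy to check in isolation; the suggestion dicts here are hypothetical and the publishing step is omitted:

```python
# Hypothetical suggestions; only the key check from above is exercised.
required = ("body", "relevant_file", "relevant_lines_start")
suggestions = [
    {"body": "Use a context manager", "relevant_file": "app.py", "relevant_lines_start": 10},
    {"body": "Missing the file fields"},
]
valid = [s for s in suggestions if all(key in s for key in required)]
assert len(valid) == 1
```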
@@ -227,12 +255,22 @@ class CodeCommitProvider(GitProvider):
def remove_comment(self, comment):
return "" # not implemented yet
def publish_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str, original_suggestion=None):
def publish_inline_comment(
self,
body: str,
relevant_file: str,
relevant_line_in_file: str,
original_suggestion=None,
):
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/codecommit/client/post_comment_for_compared_commit.html
raise NotImplementedError("CodeCommit provider does not support publishing inline comments yet")
raise NotImplementedError(
"CodeCommit provider does not support publishing inline comments yet"
)
def publish_inline_comments(self, comments: list[dict]):
raise NotImplementedError("CodeCommit provider does not support publishing inline comments yet")
raise NotImplementedError(
"CodeCommit provider does not support publishing inline comments yet"
)
def get_title(self):
return self.pr.title
@ -257,7 +295,7 @@ class CodeCommitProvider(GitProvider):
- dict: A dictionary where each key is a language name and the corresponding value is the percentage of that language in the PR.
"""
commit_files = self.get_files()
filenames = [ item.filename for item in commit_files ]
filenames = [item.filename for item in commit_files]
extensions = CodeCommitProvider._get_file_extensions(filenames)
# Calculate the percentage of each file extension in the PR
@ -270,7 +308,9 @@ class CodeCommitProvider(GitProvider):
# We build that language->extension dictionary here in main_extensions_flat.
main_extensions_flat = {}
language_extension_map_org = get_settings().language_extension_map_org
language_extension_map = {k.lower(): v for k, v in language_extension_map_org.items()}
language_extension_map = {
k.lower(): v for k, v in language_extension_map_org.items()
}
for language, extensions in language_extension_map.items():
for ext in extensions:
main_extensions_flat[ext] = language
@ -292,14 +332,20 @@ class CodeCommitProvider(GitProvider):
return -1 # not implemented yet
def get_issue_comments(self):
raise NotImplementedError("CodeCommit provider does not support issue comments yet")
raise NotImplementedError(
"CodeCommit provider does not support issue comments yet"
)
def get_repo_settings(self):
# a local ".pr_agent.toml" settings file is optional
settings_filename = ".pr_agent.toml"
return self.codecommit_client.get_file(self.repo_name, settings_filename, self.pr.source_commit, optional=True)
return self.codecommit_client.get_file(
self.repo_name, settings_filename, self.pr.source_commit, optional=True
)
def add_eyes_reaction(self, issue_comment_id: int, disable_eyes: bool = False) -> Optional[int]:
def add_eyes_reaction(
self, issue_comment_id: int, disable_eyes: bool = False
) -> Optional[int]:
get_logger().info("CodeCommit provider does not support eyes reaction yet")
return True
@ -323,7 +369,9 @@ class CodeCommitProvider(GitProvider):
parsed_url = urlparse(pr_url)
if not CodeCommitProvider._is_valid_codecommit_hostname(parsed_url.netloc):
raise ValueError(f"The provided URL is not a valid CodeCommit URL: {pr_url}")
raise ValueError(
f"The provided URL is not a valid CodeCommit URL: {pr_url}"
)
path_parts = parsed_url.path.strip("/").split("/")
@ -334,14 +382,18 @@ class CodeCommitProvider(GitProvider):
or path_parts[2] != "repositories"
or path_parts[4] != "pull-requests"
):
raise ValueError(f"The provided URL does not appear to be a CodeCommit PR URL: {pr_url}")
raise ValueError(
f"The provided URL does not appear to be a CodeCommit PR URL: {pr_url}"
)
repo_name = path_parts[3]
try:
pr_number = int(path_parts[5])
except ValueError as e:
raise ValueError(f"Unable to convert PR number to integer: '{path_parts[5]}'") from e
raise ValueError(
f"Unable to convert PR number to integer: '{path_parts[5]}'"
) from e
return repo_name, pr_number
@ -359,7 +411,12 @@ class CodeCommitProvider(GitProvider):
Returns:
- bool: True if the hostname is valid, False otherwise.
"""
return re.match(r"^[a-z]{2}-(gov-)?[a-z]+-\d\.console\.aws\.amazon\.com$", hostname) is not None
return (
re.match(
r"^[a-z]{2}-(gov-)?[a-z]+-\d\.console\.aws\.amazon\.com$", hostname
)
is not None
)
def _get_pr(self):
response = self.codecommit_client.get_pr(self.repo_name, self.pr_num)
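As a note on the hostname check above, here is a minimal standalone sketch of how the _is_valid_codecommit_hostname pattern behaves; the sample hostnames are illustrative and not taken from the commit.

import re

CODECOMMIT_HOSTNAME = re.compile(r"^[a-z]{2}-(gov-)?[a-z]+-\d\.console\.aws\.amazon\.com$")

for hostname in (
    "us-east-1.console.aws.amazon.com",      # standard region console host -> matches
    "us-gov-west-1.console.aws.amazon.com",  # GovCloud region console host -> matches
    "example.com",                           # not an AWS console host -> no match
):
    print(hostname, CODECOMMIT_HOSTNAME.match(hostname) is not None)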
View File
@ -38,10 +38,7 @@ def clone(url, directory):
def fetch(url, refspec, cwd):
get_logger().info("Fetching %s %s", url, refspec)
stdout = _call(
'git', 'fetch', '--depth', '2', url, refspec,
cwd=cwd
)
stdout = _call('git', 'fetch', '--depth', '2', url, refspec, cwd=cwd)
get_logger().info(stdout)
@ -75,10 +72,13 @@ def add_comment(url: urllib3.util.Url, refspec, message):
message = "'" + message.replace("'", "'\"'\"'") + "'"
return _call(
"ssh",
"-p", str(url.port),
"-p",
str(url.port),
f"{url.auth}@{url.host}",
"gerrit", "review",
"--message", message,
"gerrit",
"review",
"--message",
message,
# "--code-review", score,
f"{patchset},{changenum}",
)
@ -88,19 +88,23 @@ def list_comments(url: urllib3.util.Url, refspec):
*_, patchset, _ = refspec.rsplit("/")
stdout = _call(
"ssh",
"-p", str(url.port),
"-p",
str(url.port),
f"{url.auth}@{url.host}",
"gerrit", "query",
"gerrit",
"query",
"--comments",
"--current-patch-set", patchset,
"--format", "JSON",
"--current-patch-set",
patchset,
"--format",
"JSON",
)
change_set, *_ = stdout.splitlines()
return json.loads(change_set)["currentPatchSet"]["comments"]
def prepare_repo(url: urllib3.util.Url, project, refspec):
repo_url = (f"{url.scheme}://{url.auth}@{url.host}:{url.port}/{project}")
repo_url = f"{url.scheme}://{url.auth}@{url.host}:{url.port}/{project}"
directory = pathlib.Path(mkdtemp())
clone(repo_url, directory),
@ -114,18 +118,18 @@ def adopt_to_gerrit_message(message):
buf = []
for line in lines:
# remove markdown formatting
line = (line.replace("*", "")
.replace("``", "`")
.replace("<details>", "")
.replace("</details>", "")
.replace("<summary>", "")
.replace("</summary>", ""))
line = (
line.replace("*", "")
.replace("``", "`")
.replace("<details>", "")
.replace("</details>", "")
.replace("<summary>", "")
.replace("</summary>", "")
)
line = line.strip()
if line.startswith('#'):
buf.append("\n" +
line.replace('#', '').removesuffix(":").strip() +
":")
buf.append("\n" + line.replace('#', '').removesuffix(":").strip() + ":")
continue
elif line.startswith('-'):
buf.append(line.removeprefix('-').strip())
@ -136,12 +140,9 @@ def adopt_to_gerrit_message(message):
def add_suggestion(src_filename, context: str, start, end: int):
with (
NamedTemporaryFile("w", delete=False) as tmp,
open(src_filename, "r") as src
):
with NamedTemporaryFile("w", delete=False) as tmp, open(src_filename, "r") as src:
lines = src.readlines()
tmp.writelines(lines[:start - 1])
tmp.writelines(lines[: start - 1])
if context:
tmp.write(context)
tmp.writelines(lines[end:])
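To make the slicing in add_suggestion above concrete: lines[: start - 1] keeps everything before the 1-based start line and lines[end:] keeps everything after the 1-based end line, so the inclusive range start..end is swapped for the new context. A small standalone illustration with made-up data:

original = ["line 1\n", "line 2\n", "line 3\n", "line 4\n"]
start, end = 2, 3                      # replace 1-based lines 2..3 (inclusive)
context = "patched block\n"
patched = original[: start - 1] + [context] + original[end:]
assert patched == ["line 1\n", "patched block\n", "line 4\n"]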
@ -151,10 +152,8 @@ def add_suggestion(src_filename, context: str, start, end: int):
def upload_patch(patch, path):
patch_server_endpoint = get_settings().get(
'gerrit.patch_server_endpoint')
patch_server_token = get_settings().get(
'gerrit.patch_server_token')
patch_server_endpoint = get_settings().get('gerrit.patch_server_endpoint')
patch_server_token = get_settings().get('gerrit.patch_server_token')
response = requests.post(
patch_server_endpoint,
@ -165,7 +164,7 @@ def upload_patch(patch, path):
headers={
"Content-Type": "application/json",
"Authorization": f"Bearer {patch_server_token}",
}
},
)
response.raise_for_status()
patch_server_endpoint = patch_server_endpoint.rstrip("/")
@ -173,7 +172,6 @@ def upload_patch(patch, path):
class GerritProvider(GitProvider):
def __init__(self, key: str, incremental=False):
self.project, self.refspec = key.split(':')
assert self.project, "Project name is required"
@ -188,9 +186,7 @@ class GerritProvider(GitProvider):
f"{parsed.scheme}://{user}@{parsed.host}:{parsed.port}"
)
self.repo_path = prepare_repo(
self.parsed_url, self.project, self.refspec
)
self.repo_path = prepare_repo(self.parsed_url, self.project, self.refspec)
self.repo = Repo(self.repo_path)
assert self.repo
self.pr_url = base_url
@ -210,15 +206,18 @@ class GerritProvider(GitProvider):
def get_pr_labels(self, update=False):
raise NotImplementedError(
'Getting labels is not implemented for the gerrit provider')
'Getting labels is not implemented for the gerrit provider'
)
def add_eyes_reaction(self, issue_comment_id: int, disable_eyes: bool = False):
raise NotImplementedError(
'Adding reactions is not implemented for the gerrit provider')
'Adding reactions is not implemented for the gerrit provider'
)
def remove_reaction(self, issue_comment_id: int, reaction_id: int):
raise NotImplementedError(
'Removing reactions is not implemented for the gerrit provider')
'Removing reactions is not implemented for the gerrit provider'
)
def get_commit_messages(self):
return [self.repo.head.commit.message]
@ -235,20 +234,21 @@ class GerritProvider(GitProvider):
diffs = self.repo.head.commit.diff(
self.repo.head.commit.parents[0], # previous commit
create_patch=True,
R=True
R=True,
)
diff_files = []
for diff_item in diffs:
if diff_item.a_blob is not None:
original_file_content_str = (
diff_item.a_blob.data_stream.read().decode('utf-8')
original_file_content_str = diff_item.a_blob.data_stream.read().decode(
'utf-8'
)
else:
original_file_content_str = "" # empty file
if diff_item.b_blob is not None:
new_file_content_str = diff_item.b_blob.data_stream.read(). \
decode('utf-8')
new_file_content_str = diff_item.b_blob.data_stream.read().decode(
'utf-8'
)
else:
new_file_content_str = "" # empty file
edit_type = EDIT_TYPE.MODIFIED
@ -267,7 +267,7 @@ class GerritProvider(GitProvider):
edit_type=edit_type,
old_filename=None
if diff_item.a_path == diff_item.b_path
else diff_item.a_path
else diff_item.a_path,
)
)
self.diff_files = diff_files
@ -275,8 +275,7 @@ class GerritProvider(GitProvider):
def get_files(self):
diff_index = self.repo.head.commit.diff(
self.repo.head.commit.parents[0], # previous commit
R=True
self.repo.head.commit.parents[0], R=True # previous commit
)
# Get the list of changed files
diff_files = [item.a_path for item in diff_index]
@ -288,16 +287,22 @@ class GerritProvider(GitProvider):
prioritisation.
"""
# Get all files in repository
filepaths = [Path(item.path) for item in
self.repo.tree().traverse() if item.type == 'blob']
filepaths = [
Path(item.path)
for item in self.repo.tree().traverse()
if item.type == 'blob'
]
# Identify language by file extension and count
lang_count = Counter(
ext.lstrip('.') for filepath in filepaths for ext in
[filepath.suffix.lower()])
ext.lstrip('.')
for filepath in filepaths
for ext in [filepath.suffix.lower()]
)
# Convert counts to percentages
total_files = len(filepaths)
lang_percentage = {lang: count / total_files * 100 for lang, count
in lang_count.items()}
lang_percentage = {
lang: count / total_files * 100 for lang, count in lang_count.items()
}
return lang_percentage
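The language detection above boils down to a Counter over lowercased file suffixes followed by a percentage conversion; a minimal sketch with hypothetical paths:

from collections import Counter
from pathlib import Path

filepaths = [Path("src/main.py"), Path("src/utils.py"), Path("web/app.js")]
lang_count = Counter(p.suffix.lower().lstrip('.') for p in filepaths)
total_files = len(filepaths)
lang_percentage = {lang: count / total_files * 100 for lang, count in lang_count.items()}
# -> {'py': 66.66..., 'js': 33.33...}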
def get_pr_description_full(self):
@ -312,7 +317,7 @@ class GerritProvider(GitProvider):
'create_inline_comment',
'publish_inline_comments',
'get_labels',
'gfm_markdown'
'gfm_markdown',
]:
return False
return True
@ -331,14 +336,9 @@ class GerritProvider(GitProvider):
if is_code_context:
context.append(line)
else:
description.append(
line.replace('*', '')
)
description.append(line.replace('*', ''))
return (
'\n'.join(description),
'\n'.join(context) + '\n' if context else ''
)
return ('\n'.join(description), '\n'.join(context) + '\n' if context else '')
def publish_code_suggestions(self, code_suggestions: list):
msg = []
@ -372,15 +372,19 @@ class GerritProvider(GitProvider):
def publish_inline_comments(self, comments: list[dict]):
raise NotImplementedError(
'Publishing inline comments is not implemented for the gerrit '
'provider')
'Publishing inline comments is not implemented for the gerrit ' 'provider'
)
def publish_inline_comment(self, body: str, relevant_file: str,
relevant_line_in_file: str, original_suggestion=None):
def publish_inline_comment(
self,
body: str,
relevant_file: str,
relevant_line_in_file: str,
original_suggestion=None,
):
raise NotImplementedError(
'Publishing inline comments is not implemented for the gerrit '
'provider')
'Publishing inline comments is not implemented for the gerrit ' 'provider'
)
def publish_labels(self, labels):
# Not applicable to the local git provider,
View File
@ -1,4 +1,5 @@
from abc import ABC, abstractmethod
# enum EDIT_TYPE (ADDED, DELETED, MODIFIED, RENAMED)
from typing import Optional
@ -9,6 +10,7 @@ from utils.pr_agent.log import get_logger
MAX_FILES_ALLOWED_FULL = 50
class GitProvider(ABC):
@abstractmethod
def is_supported(self, capability: str) -> bool:
@ -61,11 +63,18 @@ class GitProvider(ABC):
def reply_to_comment_from_comment_id(self, comment_id: int, body: str):
pass
def get_pr_description(self, full: bool = True, split_changes_walkthrough=False) -> str or tuple:
def get_pr_description(
self, full: bool = True, split_changes_walkthrough=False
) -> str or tuple:
from utils.pr_agent.algo.utils import clip_tokens
from utils.pr_agent.config_loader import get_settings
max_tokens_description = get_settings().get("CONFIG.MAX_DESCRIPTION_TOKENS", None)
description = self.get_pr_description_full() if full else self.get_user_description()
max_tokens_description = get_settings().get(
"CONFIG.MAX_DESCRIPTION_TOKENS", None
)
description = (
self.get_pr_description_full() if full else self.get_user_description()
)
if split_changes_walkthrough:
description, files = process_description(description)
if max_tokens_description:
@ -94,7 +103,9 @@ class GitProvider(ABC):
# return nothing (empty string) because it means there is no user description
user_description_header = "### **user description**"
if user_description_header not in description_lowercase:
get_logger().info(f"Existing description was generated by the pr-agent, but it doesn't contain a user description")
get_logger().info(
f"Existing description was generated by the pr-agent, but it doesn't contain a user description"
)
return ""
# otherwise, extract the original user description from the existing pr-agent description and return it
@ -103,9 +114,11 @@ class GitProvider(ABC):
# the 'user description' is in the beginning. extract and return it
possible_headers = self._possible_headers()
start_position = description_lowercase.find(user_description_header) + len(user_description_header)
start_position = description_lowercase.find(user_description_header) + len(
user_description_header
)
end_position = len(description)
for header in possible_headers: # try to clip at the next header
for header in possible_headers: # try to clip at the next header
if header != user_description_header and header in description_lowercase:
end_position = min(end_position, description_lowercase.find(header))
if end_position != len(description) and end_position > start_position:
@ -115,20 +128,34 @@ class GitProvider(ABC):
else:
original_user_description = description.split("___")[0].strip()
if original_user_description.lower().startswith(user_description_header):
original_user_description = original_user_description[len(user_description_header):].strip()
original_user_description = original_user_description[
len(user_description_header) :
].strip()
get_logger().info(f"Extracted user description from existing description",
description=original_user_description)
get_logger().info(
f"Extracted user description from existing description",
description=original_user_description,
)
self.user_description = original_user_description
return original_user_description
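A self-contained sketch of the header-clipping logic above, using a hypothetical pr-agent style description; the header strings mirror _possible_headers below:

description = (
    "### **User description**\nFix the login redirect bug.\n\n"
    "### **PR Type**\nBug fix\n"
)
description_lowercase = description.lower()
user_description_header = "### **user description**"
possible_headers = ("### **user description**", "### **pr type**")

start_position = description_lowercase.find(user_description_header) + len(user_description_header)
end_position = len(description)
for header in possible_headers:  # try to clip at the next header
    if header != user_description_header and header in description_lowercase:
        end_position = min(end_position, description_lowercase.find(header))
print(description[start_position:end_position].strip())  # -> "Fix the login redirect bug."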
def _possible_headers(self):
return ("### **user description**", "### **pr type**", "### **pr description**", "### **pr labels**", "### **type**", "### **description**",
"### **labels**", "### 🤖 generated by pr agent")
return (
"### **user description**",
"### **pr type**",
"### **pr description**",
"### **pr labels**",
"### **type**",
"### **description**",
"### **labels**",
"### 🤖 generated by pr agent",
)
def _is_generated_by_pr_agent(self, description_lowercase: str) -> bool:
possible_headers = self._possible_headers()
return any(description_lowercase.startswith(header) for header in possible_headers)
return any(
description_lowercase.startswith(header) for header in possible_headers
)
@abstractmethod
def get_repo_settings(self):
@ -140,10 +167,17 @@ class GitProvider(ABC):
def get_pr_id(self):
return ""
def get_line_link(self, relevant_file: str, relevant_line_start: int, relevant_line_end: int = None) -> str:
def get_line_link(
self,
relevant_file: str,
relevant_line_start: int,
relevant_line_end: int = None,
) -> str:
return ""
def get_lines_link_original_file(self, filepath:str, component_range: Range) -> str:
def get_lines_link_original_file(
self, filepath: str, component_range: Range
) -> str:
return ""
#### comments operations ####
@ -151,18 +185,24 @@ class GitProvider(ABC):
def publish_comment(self, pr_comment: str, is_temporary: bool = False):
pass
def publish_persistent_comment(self, pr_comment: str,
initial_header: str,
update_header: bool = True,
name='review',
final_update_message=True):
def publish_persistent_comment(
self,
pr_comment: str,
initial_header: str,
update_header: bool = True,
name='review',
final_update_message=True,
):
self.publish_comment(pr_comment)
def publish_persistent_comment_full(self, pr_comment: str,
initial_header: str,
update_header: bool = True,
name='review',
final_update_message=True):
def publish_persistent_comment_full(
self,
pr_comment: str,
initial_header: str,
update_header: bool = True,
name='review',
final_update_message=True,
):
try:
prev_comments = list(self.get_issue_comments())
for comment in prev_comments:
@ -171,29 +211,46 @@ class GitProvider(ABC):
comment_url = self.get_comment_url(comment)
if update_header:
updated_header = f"{initial_header}\n\n#### ({name.capitalize()} updated until commit {latest_commit_url})\n"
pr_comment_updated = pr_comment.replace(initial_header, updated_header)
pr_comment_updated = pr_comment.replace(
initial_header, updated_header
)
else:
pr_comment_updated = pr_comment
get_logger().info(f"Persistent mode - updating comment {comment_url} to latest {name} message")
get_logger().info(
f"Persistent mode - updating comment {comment_url} to latest {name} message"
)
# response = self.mr.notes.update(comment.id, {'body': pr_comment_updated})
self.edit_comment(comment, pr_comment_updated)
if final_update_message:
self.publish_comment(
f"**[Persistent {name}]({comment_url})** updated to latest commit {latest_commit_url}")
f"**[Persistent {name}]({comment_url})** updated to latest commit {latest_commit_url}"
)
return
except Exception as e:
get_logger().exception(f"Failed to update persistent review, error: {e}")
pass
self.publish_comment(pr_comment)
@abstractmethod
def publish_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str, original_suggestion=None):
def publish_inline_comment(
self,
body: str,
relevant_file: str,
relevant_line_in_file: str,
original_suggestion=None,
):
pass
def create_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str,
absolute_position: int = None):
raise NotImplementedError("This git provider does not support creating inline comments yet")
def create_inline_comment(
self,
body: str,
relevant_file: str,
relevant_line_in_file: str,
absolute_position: int = None,
):
raise NotImplementedError(
"This git provider does not support creating inline comments yet"
)
@abstractmethod
def publish_inline_comments(self, comments: list[dict]):
@ -227,7 +284,9 @@ class GitProvider(ABC):
pass
@abstractmethod
def add_eyes_reaction(self, issue_comment_id: int, disable_eyes: bool = False) -> Optional[int]:
def add_eyes_reaction(
self, issue_comment_id: int, disable_eyes: bool = False
) -> Optional[int]:
pass
@abstractmethod
@ -284,16 +343,23 @@ def get_main_pr_language(languages, files) -> str:
if not file:
continue
if isinstance(file, str):
file = FilePatchInfo(base_file=None, head_file=None, patch=None, filename=file)
file = FilePatchInfo(
base_file=None, head_file=None, patch=None, filename=file
)
extension_list.append(file.filename.rsplit('.')[-1])
# get the most common extension
most_common_extension = '.' + max(set(extension_list), key=extension_list.count)
try:
language_extension_map_org = get_settings().language_extension_map_org
language_extension_map = {k.lower(): v for k, v in language_extension_map_org.items()}
language_extension_map = {
k.lower(): v for k, v in language_extension_map_org.items()
}
if top_language in language_extension_map and most_common_extension in language_extension_map[top_language]:
if (
top_language in language_extension_map
and most_common_extension in language_extension_map[top_language]
):
main_language_str = top_language
else:
for language, extensions in language_extension_map.items():
@ -332,8 +398,6 @@ def get_main_pr_language(languages, files) -> str:
return main_language_str
class IncrementalPR:
def __init__(self, is_incremental: bool = False):
self.is_incremental = is_incremental
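A simplified, hypothetical sketch of the extension-majority step in get_main_pr_language above; the language map here is made up rather than read from get_settings().language_extension_map_org:

filenames = ["src/app.py", "src/models.py", "web/index.js"]
extension_list = [name.rsplit('.')[-1] for name in filenames]
most_common_extension = '.' + max(set(extension_list), key=extension_list.count)

language_extension_map = {"python": ['.py'], "javascript": ['.js']}
top_language = "python"  # e.g. the language the hosting platform reports as dominant
if most_common_extension in language_extension_map.get(top_language, []):
    main_language_str = top_language
else:
    main_language_str = ""
print(main_language_str)  # -> "python"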
View File
@ -18,14 +18,23 @@ from ..algo.file_filter import filter_ignored
from ..algo.git_patch_processing import extract_hunk_headers
from ..algo.language_handler import is_valid_file
from ..algo.types import EDIT_TYPE
from ..algo.utils import (PRReviewHeader, Range, clip_tokens,
find_line_number_of_relevant_line_in_file,
load_large_diff, set_file_languages)
from ..algo.utils import (
PRReviewHeader,
Range,
clip_tokens,
find_line_number_of_relevant_line_in_file,
load_large_diff,
set_file_languages,
)
from ..config_loader import get_settings
from ..log import get_logger
from ..servers.utils import RateLimitExceeded
from .git_provider import (MAX_FILES_ALLOWED_FULL, FilePatchInfo, GitProvider,
IncrementalPR)
from .git_provider import (
MAX_FILES_ALLOWED_FULL,
FilePatchInfo,
GitProvider,
IncrementalPR,
)
class GithubProvider(GitProvider):
@ -36,8 +45,14 @@ class GithubProvider(GitProvider):
except Exception:
self.installation_id = None
self.max_comment_chars = 65000
self.base_url = get_settings().get("GITHUB.BASE_URL", "https://api.github.com").rstrip("/") # "https://api.github.com"
self.base_url_html = self.base_url.split("api/")[0].rstrip("/") if "api/" in self.base_url else "https://github.com"
self.base_url = (
get_settings().get("GITHUB.BASE_URL", "https://api.github.com").rstrip("/")
) # "https://api.github.com"
self.base_url_html = (
self.base_url.split("api/")[0].rstrip("/")
if "api/" in self.base_url
else "https://github.com"
)
self.github_client = self._get_github_client()
self.repo = None
self.pr_num = None
@ -50,7 +65,9 @@ class GithubProvider(GitProvider):
self.set_pr(pr_url)
self.pr_commits = list(self.pr.get_commits())
self.last_commit_id = self.pr_commits[-1]
self.pr_url = self.get_pr_url() # pr_url for github actions can be as api.github.com, so we need to get the url from the pr object
self.pr_url = (
self.get_pr_url()
) # pr_url for github actions can be as api.github.com, so we need to get the url from the pr object
else:
self.pr_commits = None
@ -80,10 +97,14 @@ class GithubProvider(GitProvider):
# Get all files changed during the commit range
for commit in self.incremental.commits_range:
if commit.commit.message.startswith(f"Merge branch '{self._get_repo().default_branch}'"):
if commit.commit.message.startswith(
f"Merge branch '{self._get_repo().default_branch}'"
):
get_logger().info(f"Skipping merge commit {commit.commit.message}")
continue
self.unreviewed_files_set.update({file.filename: file for file in commit.files})
self.unreviewed_files_set.update(
{file.filename: file for file in commit.files}
)
else:
get_logger().info("No previous review found, will review the entire PR")
self.incremental.is_incremental = False
@ -98,7 +119,11 @@ class GithubProvider(GitProvider):
else:
self.incremental.last_seen_commit = self.pr_commits[index]
break
return self.pr_commits[first_new_commit_index:] if first_new_commit_index is not None else []
return (
self.pr_commits[first_new_commit_index:]
if first_new_commit_index is not None
else []
)
def get_previous_review(self, *, full: bool, incremental: bool):
if not (full or incremental):
@ -121,7 +146,7 @@ class GithubProvider(GitProvider):
git_files = context.get("git_files", None)
if git_files:
return git_files
self.git_files = list(self.pr.get_files()) # 'list' to handle pagination
self.git_files = list(self.pr.get_files()) # 'list' to handle pagination
context["git_files"] = self.git_files
return self.git_files
except Exception:
@ -138,8 +163,13 @@ class GithubProvider(GitProvider):
except Exception as e:
return -1
@retry(exceptions=RateLimitExceeded,
tries=get_settings().github.ratelimit_retries, delay=2, backoff=2, jitter=(1, 3))
@retry(
exceptions=RateLimitExceeded,
tries=get_settings().github.ratelimit_retries,
delay=2,
backoff=2,
jitter=(1, 3),
)
def get_diff_files(self) -> list[FilePatchInfo]:
"""
Retrieves the list of files that have been modified, added, deleted, or renamed in a pull request in GitHub,
@ -167,9 +197,10 @@ class GithubProvider(GitProvider):
try:
names_original = [file.filename for file in files_original]
names_new = [file.filename for file in files]
get_logger().info(f"Filtered out [ignore] files for pull request:", extra=
{"files": names_original,
"filtered_files": names_new})
get_logger().info(
f"Filtered out [ignore] files for pull request:",
extra={"files": names_original, "filtered_files": names_new},
)
except Exception:
pass
@ -184,14 +215,17 @@ class GithubProvider(GitProvider):
repo = self.repo_obj
pr = self.pr
try:
compare = repo.compare(pr.base.sha, pr.head.sha) # communication with GitHub
compare = repo.compare(
pr.base.sha, pr.head.sha
) # communication with GitHub
merge_base_commit = compare.merge_base_commit
except Exception as e:
get_logger().error(f"Failed to get merge base commit: {e}")
merge_base_commit = pr.base
if merge_base_commit.sha != pr.base.sha:
get_logger().info(
f"Using merge base commit {merge_base_commit.sha} instead of base commit ")
f"Using merge base commit {merge_base_commit.sha} instead of base commit "
)
counter_valid = 0
for file in files:
@ -207,29 +241,48 @@ class GithubProvider(GitProvider):
# allow only a limited number of files to be fully loaded. We can manage the rest with diffs only
counter_valid += 1
avoid_load = False
if counter_valid >= MAX_FILES_ALLOWED_FULL and patch and not self.incremental.is_incremental:
if (
counter_valid >= MAX_FILES_ALLOWED_FULL
and patch
and not self.incremental.is_incremental
):
avoid_load = True
if counter_valid == MAX_FILES_ALLOWED_FULL:
get_logger().info(f"Too many files in PR, will avoid loading full content for rest of files")
get_logger().info(
f"Too many files in PR, will avoid loading full content for rest of files"
)
if avoid_load:
new_file_content_str = ""
else:
new_file_content_str = self._get_pr_file_content(file, self.pr.head.sha) # communication with GitHub
new_file_content_str = self._get_pr_file_content(
file, self.pr.head.sha
) # communication with GitHub
if self.incremental.is_incremental and self.unreviewed_files_set:
original_file_content_str = self._get_pr_file_content(file, self.incremental.last_seen_commit_sha)
patch = load_large_diff(file.filename, new_file_content_str, original_file_content_str)
original_file_content_str = self._get_pr_file_content(
file, self.incremental.last_seen_commit_sha
)
patch = load_large_diff(
file.filename,
new_file_content_str,
original_file_content_str,
)
self.unreviewed_files_set[file.filename] = patch
else:
if avoid_load:
original_file_content_str = ""
else:
original_file_content_str = self._get_pr_file_content(file, merge_base_commit.sha)
original_file_content_str = self._get_pr_file_content(
file, merge_base_commit.sha
)
# original_file_content_str = self._get_pr_file_content(file, self.pr.base.sha)
if not patch:
patch = load_large_diff(file.filename, new_file_content_str, original_file_content_str)
patch = load_large_diff(
file.filename,
new_file_content_str,
original_file_content_str,
)
if file.status == 'added':
edit_type = EDIT_TYPE.ADDED
@ -249,16 +302,27 @@ class GithubProvider(GitProvider):
num_minus_lines = file.deletions
else:
patch_lines = patch.splitlines(keepends=True)
num_plus_lines = len([line for line in patch_lines if line.startswith('+')])
num_minus_lines = len([line for line in patch_lines if line.startswith('-')])
num_plus_lines = len(
[line for line in patch_lines if line.startswith('+')]
)
num_minus_lines = len(
[line for line in patch_lines if line.startswith('-')]
)
file_patch_canonical_structure = FilePatchInfo(original_file_content_str, new_file_content_str, patch,
file.filename, edit_type=edit_type,
num_plus_lines=num_plus_lines,
num_minus_lines=num_minus_lines,)
file_patch_canonical_structure = FilePatchInfo(
original_file_content_str,
new_file_content_str,
patch,
file.filename,
edit_type=edit_type,
num_plus_lines=num_plus_lines,
num_minus_lines=num_minus_lines,
)
diff_files.append(file_patch_canonical_structure)
if invalid_files_names:
get_logger().info(f"Filtered out files with invalid extensions: {invalid_files_names}")
get_logger().info(
f"Filtered out files with invalid extensions: {invalid_files_names}"
)
self.diff_files = diff_files
try:
@ -269,8 +333,10 @@ class GithubProvider(GitProvider):
return diff_files
except Exception as e:
get_logger().error(f"Failing to get diff files: {e}",
artifact={"traceback": traceback.format_exc()})
get_logger().error(
f"Failing to get diff files: {e}",
artifact={"traceback": traceback.format_exc()},
)
raise RateLimitExceeded("Rate limit exceeded for GitHub API.") from e
def publish_description(self, pr_title: str, pr_body: str):
@ -282,16 +348,23 @@ class GithubProvider(GitProvider):
def get_comment_url(self, comment) -> str:
return comment.html_url
def publish_persistent_comment(self, pr_comment: str,
initial_header: str,
update_header: bool = True,
name='review',
final_update_message=True):
self.publish_persistent_comment_full(pr_comment, initial_header, update_header, name, final_update_message)
def publish_persistent_comment(
self,
pr_comment: str,
initial_header: str,
update_header: bool = True,
name='review',
final_update_message=True,
):
self.publish_persistent_comment_full(
pr_comment, initial_header, update_header, name, final_update_message
)
def publish_comment(self, pr_comment: str, is_temporary: bool = False):
if is_temporary and not get_settings().config.publish_output_progress:
get_logger().debug(f"Skipping publish_comment for temporary comment: {pr_comment}")
get_logger().debug(
f"Skipping publish_comment for temporary comment: {pr_comment}"
)
return None
pr_comment = self.limit_output_characters(pr_comment, self.max_comment_chars)
response = self.pr.create_issue_comment(pr_comment)
@ -303,42 +376,68 @@ class GithubProvider(GitProvider):
self.pr.comments_list.append(response)
return response
def publish_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str, original_suggestion=None):
def publish_inline_comment(
self,
body: str,
relevant_file: str,
relevant_line_in_file: str,
original_suggestion=None,
):
body = self.limit_output_characters(body, self.max_comment_chars)
self.publish_inline_comments([self.create_inline_comment(body, relevant_file, relevant_line_in_file)])
self.publish_inline_comments(
[self.create_inline_comment(body, relevant_file, relevant_line_in_file)]
)
def create_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str,
absolute_position: int = None):
def create_inline_comment(
self,
body: str,
relevant_file: str,
relevant_line_in_file: str,
absolute_position: int = None,
):
body = self.limit_output_characters(body, self.max_comment_chars)
position, absolute_position = find_line_number_of_relevant_line_in_file(self.diff_files,
relevant_file.strip('`'),
relevant_line_in_file,
absolute_position)
position, absolute_position = find_line_number_of_relevant_line_in_file(
self.diff_files,
relevant_file.strip('`'),
relevant_line_in_file,
absolute_position,
)
if position == -1:
get_logger().info(f"Could not find position for {relevant_file} {relevant_line_in_file}")
get_logger().info(
f"Could not find position for {relevant_file} {relevant_line_in_file}"
)
subject_type = "FILE"
else:
subject_type = "LINE"
path = relevant_file.strip()
return dict(body=body, path=path, position=position) if subject_type == "LINE" else {}
return (
dict(body=body, path=path, position=position)
if subject_type == "LINE"
else {}
)
def publish_inline_comments(self, comments: list[dict], disable_fallback: bool = False):
def publish_inline_comments(
self, comments: list[dict], disable_fallback: bool = False
):
try:
# publish all comments in a single message
self.pr.create_review(commit=self.last_commit_id, comments=comments)
except Exception as e:
get_logger().info(f"Initially failed to publish inline comments as committable")
get_logger().info(
f"Initially failed to publish inline comments as committable"
)
if (getattr(e, "status", None) == 422 and not disable_fallback):
if getattr(e, "status", None) == 422 and not disable_fallback:
pass # continue to try _publish_inline_comments_fallback_with_verification
else:
raise e # will end up with publishing the comments one by one
raise e # will end up with publishing the comments one by one
try:
self._publish_inline_comments_fallback_with_verification(comments)
except Exception as e:
get_logger().error(f"Failed to publish inline code comments fallback, error: {e}")
get_logger().error(
f"Failed to publish inline code comments fallback, error: {e}"
)
raise e
def _publish_inline_comments_fallback_with_verification(self, comments: list[dict]):
@ -352,20 +451,27 @@ class GithubProvider(GitProvider):
# publish as a group the verified comments
if verified_comments:
try:
self.pr.create_review(commit=self.last_commit_id, comments=verified_comments)
self.pr.create_review(
commit=self.last_commit_id, comments=verified_comments
)
except:
pass
# try to publish one by one the invalid comments as a one-line code comment
if invalid_comments and get_settings().github.try_fix_invalid_inline_comments:
fixed_comments_as_one_liner = self._try_fix_invalid_inline_comments(
[comment for comment, _ in invalid_comments])
[comment for comment, _ in invalid_comments]
)
for comment in fixed_comments_as_one_liner:
try:
self.publish_inline_comments([comment], disable_fallback=True)
get_logger().info(f"Published invalid comment as a single line comment: {comment}")
get_logger().info(
f"Published invalid comment as a single line comment: {comment}"
)
except:
get_logger().error(f"Failed to publish invalid comment as a single line comment: {comment}")
get_logger().error(
f"Failed to publish invalid comment as a single line comment: {comment}"
)
def _verify_code_comment(self, comment: dict):
is_verified = False
@ -374,7 +480,8 @@ class GithubProvider(GitProvider):
# event ="" # By leaving this blank, you set the review action state to PENDING
input = dict(commit_id=self.last_commit_id.sha, comments=[comment])
headers, data = self.pr._requester.requestJsonAndCheck(
"POST", f"{self.pr.url}/reviews", input=input)
"POST", f"{self.pr.url}/reviews", input=input
)
pending_review_id = data["id"]
is_verified = True
except Exception as err:
@ -383,12 +490,16 @@ class GithubProvider(GitProvider):
e = err
if pending_review_id is not None:
try:
self.pr._requester.requestJsonAndCheck("DELETE", f"{self.pr.url}/reviews/{pending_review_id}")
self.pr._requester.requestJsonAndCheck(
"DELETE", f"{self.pr.url}/reviews/{pending_review_id}"
)
except Exception:
pass
return is_verified, e
def _verify_code_comments(self, comments: list[dict]) -> tuple[list[dict], list[tuple[dict, Exception]]]:
def _verify_code_comments(
self, comments: list[dict]
) -> tuple[list[dict], list[tuple[dict, Exception]]]:
"""Very each comment against the GitHub API and return 2 lists: 1 of verified and 1 of invalid comments"""
verified_comments = []
invalid_comments = []
@ -401,17 +512,22 @@ class GithubProvider(GitProvider):
invalid_comments.append((comment, e))
return verified_comments, invalid_comments
def _try_fix_invalid_inline_comments(self, invalid_comments: list[dict]) -> list[dict]:
def _try_fix_invalid_inline_comments(
self, invalid_comments: list[dict]
) -> list[dict]:
"""
Try fixing invalid comments by removing the suggestion part and setting the comment just on the first line.
Return only comments that have been modified in some way.
This is a best-effort attempt to fix invalid comments, and should be verified accordingly.
"""
import copy
fixed_comments = []
for comment in invalid_comments:
try:
fixed_comment = copy.deepcopy(comment) # avoid modifying the original comment dict for later logging
fixed_comment = copy.deepcopy(
comment
) # avoid modifying the original comment dict for later logging
if "```suggestion" in comment["body"]:
fixed_comment["body"] = comment["body"].split("```suggestion")[0]
if "start_line" in comment:
@ -432,7 +548,9 @@ class GithubProvider(GitProvider):
"""
post_parameters_list = []
code_suggestions_validated = self.validate_comments_inside_hunks(code_suggestions)
code_suggestions_validated = self.validate_comments_inside_hunks(
code_suggestions
)
for suggestion in code_suggestions_validated:
body = suggestion['body']
@ -442,13 +560,16 @@ class GithubProvider(GitProvider):
if not relevant_lines_start or relevant_lines_start == -1:
get_logger().exception(
f"Failed to publish code suggestion, relevant_lines_start is {relevant_lines_start}")
f"Failed to publish code suggestion, relevant_lines_start is {relevant_lines_start}"
)
continue
if relevant_lines_end < relevant_lines_start:
get_logger().exception(f"Failed to publish code suggestion, "
f"relevant_lines_end is {relevant_lines_end} and "
f"relevant_lines_start is {relevant_lines_start}")
get_logger().exception(
f"Failed to publish code suggestion, "
f"relevant_lines_end is {relevant_lines_end} and "
f"relevant_lines_start is {relevant_lines_start}"
)
continue
if relevant_lines_end > relevant_lines_start:
@ -484,17 +605,21 @@ class GithubProvider(GitProvider):
# Log as warning for permission-related issues (usually due to polling)
get_logger().warning(
"Failed to edit github comment due to permission restrictions",
artifact={"error": e})
artifact={"error": e},
)
else:
get_logger().exception(f"Failed to edit github comment", artifact={"error": e})
get_logger().exception(
f"Failed to edit github comment", artifact={"error": e}
)
def edit_comment_from_comment_id(self, comment_id: int, body: str):
try:
# self.pr.get_issue_comment(comment_id).edit(body)
body = self.limit_output_characters(body, self.max_comment_chars)
headers, data_patch = self.pr._requester.requestJsonAndCheck(
"PATCH", f"{self.base_url}/repos/{self.repo}/issues/comments/{comment_id}",
input={"body": body}
"PATCH",
f"{self.base_url}/repos/{self.repo}/issues/comments/{comment_id}",
input={"body": body},
)
except Exception as e:
get_logger().exception(f"Failed to edit comment, error: {e}")
@ -504,8 +629,9 @@ class GithubProvider(GitProvider):
# self.pr.get_issue_comment(comment_id).edit(body)
body = self.limit_output_characters(body, self.max_comment_chars)
headers, data_patch = self.pr._requester.requestJsonAndCheck(
"POST", f"{self.base_url}/repos/{self.repo}/pulls/{self.pr_num}/comments/{comment_id}/replies",
input={"body": body}
"POST",
f"{self.base_url}/repos/{self.repo}/pulls/{self.pr_num}/comments/{comment_id}/replies",
input={"body": body},
)
except Exception as e:
get_logger().exception(f"Failed to reply comment, error: {e}")
@ -516,7 +642,7 @@ class GithubProvider(GitProvider):
headers, data_patch = self.pr._requester.requestJsonAndCheck(
"GET", f"{self.base_url}/repos/{self.repo}/issues/comments/{comment_id}"
)
return data_patch.get("body","")
return data_patch.get("body", "")
except Exception as e:
get_logger().exception(f"Failed to edit comment, error: {e}")
return None
@ -528,7 +654,9 @@ class GithubProvider(GitProvider):
)
for comment in file_comments:
comment['commit_id'] = self.last_commit_id.sha
comment['body'] = self.limit_output_characters(comment['body'], self.max_comment_chars)
comment['body'] = self.limit_output_characters(
comment['body'], self.max_comment_chars
)
found = False
for existing_comment in existing_comments:
@ -536,13 +664,23 @@ class GithubProvider(GitProvider):
our_app_name = get_settings().get("GITHUB.APP_NAME", "")
same_comment_creator = False
if self.deployment_type == 'app':
same_comment_creator = our_app_name.lower() in existing_comment['user']['login'].lower()
same_comment_creator = (
our_app_name.lower()
in existing_comment['user']['login'].lower()
)
elif self.deployment_type == 'user':
same_comment_creator = self.github_user_id == existing_comment['user']['login']
if existing_comment['subject_type'] == 'file' and comment['path'] == existing_comment['path'] and same_comment_creator:
same_comment_creator = (
self.github_user_id == existing_comment['user']['login']
)
if (
existing_comment['subject_type'] == 'file'
and comment['path'] == existing_comment['path']
and same_comment_creator
):
headers, data_patch = self.pr._requester.requestJsonAndCheck(
"PATCH", f"{self.base_url}/repos/{self.repo}/pulls/comments/{existing_comment['id']}", input={"body":comment['body']}
"PATCH",
f"{self.base_url}/repos/{self.repo}/pulls/comments/{existing_comment['id']}",
input={"body": comment['body']},
)
found = True
break
@ -600,7 +738,9 @@ class GithubProvider(GitProvider):
deployment_type = get_settings().get("GITHUB.DEPLOYMENT_TYPE", "user")
if deployment_type != 'user':
raise ValueError("Deployment mode must be set to 'user' to get notifications")
raise ValueError(
"Deployment mode must be set to 'user' to get notifications"
)
notifications = self.github_client.get_user().get_notifications(since=since)
return notifications
@ -621,13 +761,16 @@ class GithubProvider(GitProvider):
def get_workspace_name(self):
return self.repo.split('/')[0]
def add_eyes_reaction(self, issue_comment_id: int, disable_eyes: bool = False) -> Optional[int]:
def add_eyes_reaction(
self, issue_comment_id: int, disable_eyes: bool = False
) -> Optional[int]:
if disable_eyes:
return None
try:
headers, data_patch = self.pr._requester.requestJsonAndCheck(
"POST", f"{self.base_url}/repos/{self.repo}/issues/comments/{issue_comment_id}/reactions",
input={"content": "eyes"}
"POST",
f"{self.base_url}/repos/{self.repo}/issues/comments/{issue_comment_id}/reactions",
input={"content": "eyes"},
)
return data_patch.get("id", None)
except Exception as e:
@ -639,7 +782,7 @@ class GithubProvider(GitProvider):
# self.pr.get_issue_comment(issue_comment_id).delete_reaction(reaction_id)
headers, data_patch = self.pr._requester.requestJsonAndCheck(
"DELETE",
f"{self.base_url}/repos/{self.repo}/issues/comments/{issue_comment_id}/reactions/{reaction_id}"
f"{self.base_url}/repos/{self.repo}/issues/comments/{issue_comment_id}/reactions/{reaction_id}",
)
return True
except Exception as e:
@ -655,7 +798,9 @@ class GithubProvider(GitProvider):
path_parts = parsed_url.path.strip('/').split('/')
if 'api.github.com' in parsed_url.netloc or '/api/v3' in pr_url:
if len(path_parts) < 5 or path_parts[3] != 'pulls':
raise ValueError("The provided URL does not appear to be a GitHub PR URL")
raise ValueError(
"The provided URL does not appear to be a GitHub PR URL"
)
repo_name = '/'.join(path_parts[1:3])
try:
pr_number = int(path_parts[4])
@ -683,7 +828,9 @@ class GithubProvider(GitProvider):
path_parts = parsed_url.path.strip('/').split('/')
if 'api.github.com' in parsed_url.netloc:
if len(path_parts) < 5 or path_parts[3] != 'issues':
raise ValueError("The provided URL does not appear to be a GitHub ISSUE URL")
raise ValueError(
"The provided URL does not appear to be a GitHub ISSUE URL"
)
repo_name = '/'.join(path_parts[1:3])
try:
issue_number = int(path_parts[4])
@ -710,11 +857,18 @@ class GithubProvider(GitProvider):
private_key = get_settings().github.private_key
app_id = get_settings().github.app_id
except AttributeError as e:
raise ValueError("GitHub app ID and private key are required when using GitHub app deployment") from e
raise ValueError(
"GitHub app ID and private key are required when using GitHub app deployment"
) from e
if not self.installation_id:
raise ValueError("GitHub app installation ID is required when using GitHub app deployment")
auth = AppAuthentication(app_id=app_id, private_key=private_key,
installation_id=self.installation_id)
raise ValueError(
"GitHub app installation ID is required when using GitHub app deployment"
)
auth = AppAuthentication(
app_id=app_id,
private_key=private_key,
installation_id=self.installation_id,
)
return Github(app_auth=auth, base_url=self.base_url)
if deployment_type == 'user':
@ -723,19 +877,21 @@ class GithubProvider(GitProvider):
except AttributeError as e:
raise ValueError(
"GitHub token is required when using user deployment. See: "
"https://github.com/Codium-ai/pr-agent#method-2-run-from-source") from e
"https://github.com/Codium-ai/pr-agent#method-2-run-from-source"
) from e
return Github(auth=Auth.Token(token), base_url=self.base_url)
def _get_repo(self):
if hasattr(self, 'repo_obj') and \
hasattr(self.repo_obj, 'full_name') and \
self.repo_obj.full_name == self.repo:
if (
hasattr(self, 'repo_obj')
and hasattr(self.repo_obj, 'full_name')
and self.repo_obj.full_name == self.repo
):
return self.repo_obj
else:
self.repo_obj = self.github_client.get_repo(self.repo)
return self.repo_obj
def _get_pr(self):
return self._get_repo().get_pull(self.pr_num)
@ -755,9 +911,9 @@ class GithubProvider(GitProvider):
) -> None:
try:
file_obj = self._get_repo().get_contents(file_path, ref=branch)
sha1=file_obj.sha
sha1 = file_obj.sha
except Exception:
sha1=""
sha1 = ""
self.repo_obj.update_file(
path=file_path,
message=message,
@ -771,9 +927,14 @@ class GithubProvider(GitProvider):
def publish_labels(self, pr_types):
try:
label_color_map = {"Bug fix": "1d76db", "Tests": "e99695", "Bug fix with tests": "c5def5",
"Enhancement": "bfd4f2", "Documentation": "d4c5f9",
"Other": "d1bcf9"}
label_color_map = {
"Bug fix": "1d76db",
"Tests": "e99695",
"Bug fix with tests": "c5def5",
"Enhancement": "bfd4f2",
"Documentation": "d4c5f9",
"Other": "d1bcf9",
}
post_parameters = []
for p in pr_types:
color = label_color_map.get(p, "d1bcf9") # default to "Other" color
@ -787,11 +948,12 @@ class GithubProvider(GitProvider):
def get_pr_labels(self, update=False):
try:
if not update:
labels =self.pr.labels
labels = self.pr.labels
return [label.name for label in labels]
else: # obtain the latest labels. Maybe they changed while the AI was running
else: # obtain the latest labels. Maybe they changed while the AI was running
headers, labels = self.pr._requester.requestJsonAndCheck(
"GET", f"{self.pr.issue_url}/labels")
"GET", f"{self.pr.issue_url}/labels"
)
return [label['name'] for label in labels]
except Exception as e:
@ -813,7 +975,9 @@ class GithubProvider(GitProvider):
try:
commit_list = self.pr.get_commits()
commit_messages = [commit.commit.message for commit in commit_list]
commit_messages_str = "\n".join([f"{i + 1}. {message}" for i, message in enumerate(commit_messages)])
commit_messages_str = "\n".join(
[f"{i + 1}. {message}" for i, message in enumerate(commit_messages)]
)
except Exception:
commit_messages_str = ""
if max_tokens:
@ -822,13 +986,16 @@ class GithubProvider(GitProvider):
def generate_link_to_relevant_line_number(self, suggestion) -> str:
try:
relevant_file = suggestion['relevant_file'].strip('`').strip("'").strip('\n')
relevant_file = (
suggestion['relevant_file'].strip('`').strip("'").strip('\n')
)
relevant_line_str = suggestion['relevant_line'].strip('\n')
if not relevant_line_str:
return ""
position, absolute_position = find_line_number_of_relevant_line_in_file \
(self.diff_files, relevant_file, relevant_line_str)
position, absolute_position = find_line_number_of_relevant_line_in_file(
self.diff_files, relevant_file, relevant_line_str
)
if absolute_position != -1:
# # link to right file only
@ -844,7 +1011,12 @@ class GithubProvider(GitProvider):
return ""
def get_line_link(self, relevant_file: str, relevant_line_start: int, relevant_line_end: int = None) -> str:
def get_line_link(
self,
relevant_file: str,
relevant_line_start: int,
relevant_line_end: int = None,
) -> str:
sha_file = hashlib.sha256(relevant_file.encode('utf-8')).hexdigest()
if relevant_line_start == -1:
link = f"{self.base_url_html}/{self.repo}/pull/{self.pr_num}/files#diff-{sha_file}"
@ -854,7 +1026,9 @@ class GithubProvider(GitProvider):
link = f"{self.base_url_html}/{self.repo}/pull/{self.pr_num}/files#diff-{sha_file}R{relevant_line_start}"
return link
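For reference, the per-file anchor built in get_line_link above is the SHA-256 of the file path appended to the PR files view; a standalone sketch with placeholder repository values:

import hashlib

base_url_html, repo, pr_num = "https://github.com", "org/repo", 42   # placeholders
relevant_file, relevant_line_start = "src/app.py", 10
sha_file = hashlib.sha256(relevant_file.encode('utf-8')).hexdigest()
link = f"{base_url_html}/{repo}/pull/{pr_num}/files#diff-{sha_file}R{relevant_line_start}"
print(link)  # .../pull/42/files#diff-<sha256 of "src/app.py">R10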
def get_lines_link_original_file(self, filepath: str, component_range: Range) -> str:
def get_lines_link_original_file(
self, filepath: str, component_range: Range
) -> str:
"""
Returns the link to the original file on GitHub that corresponds to the given filepath and component range.
@ -876,8 +1050,10 @@ class GithubProvider(GitProvider):
line_end = component_range.line_end + 1
# link = (f"https://github.com/{self.repo}/blob/{self.last_commit_id.sha}/{filepath}/"
# f"#L{line_start}-L{line_end}")
link = (f"{self.base_url_html}/{self.repo}/blob/{self.last_commit_id.sha}/{filepath}/"
f"#L{line_start}-L{line_end}")
link = (
f"{self.base_url_html}/{self.repo}/blob/{self.last_commit_id.sha}/{filepath}/"
f"#L{line_start}-L{line_end}"
)
return link
@ -909,8 +1085,9 @@ class GithubProvider(GitProvider):
}}
}}
"""
response_tuple = self.github_client._Github__requester.requestJson("POST", "/graphql",
input={"query": query})
response_tuple = self.github_client._Github__requester.requestJson(
"POST", "/graphql", input={"query": query}
)
# Extract the JSON response from the tuple and parse it
if isinstance(response_tuple, tuple) and len(response_tuple) == 3:
@ -919,8 +1096,12 @@ class GithubProvider(GitProvider):
get_logger().error(f"Unexpected response format: {response_tuple}")
return sub_issues
issue_id = response_json.get("data", {}).get("repository", {}).get("issue", {}).get("id")
issue_id = (
response_json.get("data", {})
.get("repository", {})
.get("issue", {})
.get("id")
)
if not issue_id:
get_logger().warning(f"Issue ID not found for {issue_url}")
@ -940,22 +1121,42 @@ class GithubProvider(GitProvider):
}}
}}
"""
sub_issues_response_tuple = self.github_client._Github__requester.requestJson("POST", "/graphql", input={
"query": sub_issues_query})
sub_issues_response_tuple = (
self.github_client._Github__requester.requestJson(
"POST", "/graphql", input={"query": sub_issues_query}
)
)
# Extract the JSON response from the tuple and parse it
if isinstance(sub_issues_response_tuple, tuple) and len(sub_issues_response_tuple) == 3:
if (
isinstance(sub_issues_response_tuple, tuple)
and len(sub_issues_response_tuple) == 3
):
sub_issues_response_json = json.loads(sub_issues_response_tuple[2])
else:
get_logger().error("Unexpected sub-issues response format", artifact={"response": sub_issues_response_tuple})
get_logger().error(
"Unexpected sub-issues response format",
artifact={"response": sub_issues_response_tuple},
)
return sub_issues
if not sub_issues_response_json.get("data", {}).get("node", {}).get("subIssues"):
if (
not sub_issues_response_json.get("data", {})
.get("node", {})
.get("subIssues")
):
get_logger().error("Invalid sub-issues response structure")
return sub_issues
nodes = sub_issues_response_json.get("data", {}).get("node", {}).get("subIssues", {}).get("nodes", [])
get_logger().info(f"Github Sub-issues fetched: {len(nodes)}", artifact={"nodes": nodes})
nodes = (
sub_issues_response_json.get("data", {})
.get("node", {})
.get("subIssues", {})
.get("nodes", [])
)
get_logger().info(
f"Github Sub-issues fetched: {len(nodes)}", artifact={"nodes": nodes}
)
for sub_issue in nodes:
if "url" in sub_issue:
@ -977,7 +1178,7 @@ class GithubProvider(GitProvider):
return False
def calc_pr_statistics(self, pull_request_data: dict):
return {}
return {}
def validate_comments_inside_hunks(self, code_suggestions):
"""
@ -986,7 +1187,8 @@ class GithubProvider(GitProvider):
code_suggestions_copy = copy.deepcopy(code_suggestions)
diff_files = self.get_diff_files()
RE_HUNK_HEADER = re.compile(
r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)")
r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)"
)
diff_files = set_file_languages(diff_files)
@ -995,7 +1197,6 @@ class GithubProvider(GitProvider):
relevant_file_path = suggestion['relevant_file']
for file in diff_files:
if file.filename == relevant_file_path:
# generate on-demand the patches range for the relevant file
patch_str = file.patch
if not hasattr(file, 'patches_range'):
@ -1006,14 +1207,30 @@ class GithubProvider(GitProvider):
match = RE_HUNK_HEADER.match(line)
# identify hunk header
if match:
section_header, size1, size2, start1, start2 = extract_hunk_headers(match)
file.patches_range.append({'start': start2, 'end': start2 + size2 - 1})
(
section_header,
size1,
size2,
start1,
start2,
) = extract_hunk_headers(match)
file.patches_range.append(
{'start': start2, 'end': start2 + size2 - 1}
)
patches_range = file.patches_range
comment_start_line = suggestion.get('relevant_lines_start', None)
comment_start_line = suggestion.get(
'relevant_lines_start', None
)
comment_end_line = suggestion.get('relevant_lines_end', None)
original_suggestion = suggestion.get('original_suggestion', None) # needed for diff code
if not comment_start_line or not comment_end_line or not original_suggestion:
original_suggestion = suggestion.get(
'original_suggestion', None
) # needed for diff code
if (
not comment_start_line
or not comment_end_line
or not original_suggestion
):
continue
# check if the comment is inside a valid hunk
@ -1037,30 +1254,57 @@ class GithubProvider(GitProvider):
patch_range_min = patch_range
min_distance = min(min_distance, d)
if not is_valid_hunk:
if min_distance < 10: # 10 lines - a reasonable distance to consider the comment inside the hunk
if (
min_distance < 10
): # 10 lines - a reasonable distance to consider the comment inside the hunk
# make the suggestion non-committable, yet multi line
suggestion['relevant_lines_start'] = max(suggestion['relevant_lines_start'], patch_range_min['start'])
suggestion['relevant_lines_end'] = min(suggestion['relevant_lines_end'], patch_range_min['end'])
suggestion['relevant_lines_start'] = max(
suggestion['relevant_lines_start'],
patch_range_min['start'],
)
suggestion['relevant_lines_end'] = min(
suggestion['relevant_lines_end'],
patch_range_min['end'],
)
body = suggestion['body'].strip()
# present new diff code in collapsible
existing_code = original_suggestion['existing_code'].rstrip() + "\n"
improved_code = original_suggestion['improved_code'].rstrip() + "\n"
diff = difflib.unified_diff(existing_code.split('\n'),
improved_code.split('\n'), n=999)
existing_code = (
original_suggestion['existing_code'].rstrip() + "\n"
)
improved_code = (
original_suggestion['improved_code'].rstrip() + "\n"
)
diff = difflib.unified_diff(
existing_code.split('\n'),
improved_code.split('\n'),
n=999,
)
patch_orig = "\n".join(diff)
patch = "\n".join(patch_orig.splitlines()[5:]).strip('\n')
patch = "\n".join(patch_orig.splitlines()[5:]).strip(
'\n'
)
diff_code = f"\n\n<details><summary>新提议的代码:</summary>\n\n```diff\n{patch.rstrip()}\n```"
# replace ```suggestion ... ``` with diff_code, using regex:
body = re.sub(r'```suggestion.*?```', diff_code, body, flags=re.DOTALL)
body = re.sub(
r'```suggestion.*?```',
diff_code,
body,
flags=re.DOTALL,
)
body += "\n\n</details>"
suggestion['body'] = body
get_logger().info(f"Comment was moved to a valid hunk, "
f"start_line={suggestion['relevant_lines_start']}, end_line={suggestion['relevant_lines_end']}, file={file.filename}")
get_logger().info(
f"Comment was moved to a valid hunk, "
f"start_line={suggestion['relevant_lines_start']}, end_line={suggestion['relevant_lines_end']}, file={file.filename}"
)
else:
get_logger().error(f"Comment is not inside a valid hunk, "
f"start_line={suggestion['relevant_lines_start']}, end_line={suggestion['relevant_lines_end']}, file={file.filename}")
get_logger().error(
f"Comment is not inside a valid hunk, "
f"start_line={suggestion['relevant_lines_start']}, end_line={suggestion['relevant_lines_end']}, file={file.filename}"
)
except Exception as e:
get_logger().error(f"Failed to process patch for committable comment, error: {e}")
get_logger().error(
f"Failed to process patch for committable comment, error: {e}"
)
return code_suggestions_copy
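A standalone sketch of how the hunk-header regex in validate_comments_inside_hunks above yields the new-file line range stored in patches_range. It reads the match groups directly instead of calling the repository's extract_hunk_headers helper, so the group ordering is inferred from the pattern itself:

import re

RE_HUNK_HEADER = re.compile(r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)")

header = "@@ -121,21 +129,28 @@ class CodeCommitProvider(GitProvider):"
match = RE_HUNK_HEADER.match(header)
start1, size1, start2, size2, section_header = match.groups()
start2, size2 = int(start2), int(size2)

# The hunk covers size2 lines of the new file, starting at line start2 (1-based, inclusive).
patch_range = {'start': start2, 'end': start2 + size2 - 1}
print(patch_range)  # -> {'start': 129, 'end': 156}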
View File
@ -10,9 +10,11 @@ from utils.pr_agent.algo.types import EDIT_TYPE, FilePatchInfo
from ..algo.file_filter import filter_ignored
from ..algo.language_handler import is_valid_file
from ..algo.utils import (clip_tokens,
find_line_number_of_relevant_line_in_file,
load_large_diff)
from ..algo.utils import (
clip_tokens,
find_line_number_of_relevant_line_in_file,
load_large_diff,
)
from ..config_loader import get_settings
from ..log import get_logger
from .git_provider import MAX_FILES_ALLOWED_FULL, GitProvider
@ -20,22 +22,26 @@ from .git_provider import MAX_FILES_ALLOWED_FULL, GitProvider
class DiffNotFoundError(Exception):
"""Raised when the diff for a merge request cannot be found."""
pass
class GitLabProvider(GitProvider):
def __init__(self, merge_request_url: Optional[str] = None, incremental: Optional[bool] = False):
class GitLabProvider(GitProvider):
def __init__(
self,
merge_request_url: Optional[str] = None,
incremental: Optional[bool] = False,
):
gitlab_url = get_settings().get("GITLAB.URL", None)
if not gitlab_url:
raise ValueError("GitLab URL is not set in the config file")
self.gitlab_url = gitlab_url
gitlab_access_token = get_settings().get("GITLAB.PERSONAL_ACCESS_TOKEN", None)
if not gitlab_access_token:
raise ValueError("GitLab personal access token is not set in the config file")
self.gl = gitlab.Gitlab(
url=gitlab_url,
oauth_token=gitlab_access_token
)
raise ValueError(
"GitLab personal access token is not set in the config file"
)
self.gl = gitlab.Gitlab(url=gitlab_url, oauth_token=gitlab_access_token)
self.max_comment_chars = 65000
self.id_project = None
self.id_mr = None
@ -46,12 +52,17 @@ class GitLabProvider(GitProvider):
self.pr_url = merge_request_url
self._set_merge_request(merge_request_url)
self.RE_HUNK_HEADER = re.compile(
r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)")
r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)"
)
self.incremental = incremental
def is_supported(self, capability: str) -> bool:
if capability in ['get_issue_comments', 'create_inline_comment', 'publish_inline_comments',
'publish_file_comments']: # gfm_markdown is supported in gitlab !
if capability in [
'get_issue_comments',
'create_inline_comment',
'publish_inline_comments',
'publish_file_comments',
]: # gfm_markdown is supported in gitlab !
return False
return True
@ -67,12 +78,17 @@ class GitLabProvider(GitProvider):
self.last_diff = self.mr.diffs.list(get_all=True)[-1]
except IndexError as e:
get_logger().error(f"Could not get diff for merge request {self.id_mr}")
raise DiffNotFoundError(f"Could not get diff for merge request {self.id_mr}") from e
raise DiffNotFoundError(
f"Could not get diff for merge request {self.id_mr}"
) from e
def get_pr_file_content(self, file_path: str, branch: str) -> str:
try:
return self.gl.projects.get(self.id_project).files.get(file_path, branch).decode()
return (
self.gl.projects.get(self.id_project)
.files.get(file_path, branch)
.decode()
)
except GitlabGetError:
# In case of file creation the method returns GitlabGetError (404 file not found).
# In this case we return an empty string for the diff.
@ -98,10 +114,13 @@ class GitLabProvider(GitProvider):
try:
names_original = [diff['new_path'] for diff in diffs_original]
names_filtered = [diff['new_path'] for diff in diffs]
get_logger().info(f"Filtered out [ignore] files for merge request {self.id_mr}", extra={
'original_files': names_original,
'filtered_files': names_filtered
})
get_logger().info(
f"Filtered out [ignore] files for merge request {self.id_mr}",
extra={
'original_files': names_original,
'filtered_files': names_filtered,
},
)
except Exception as e:
pass
@ -116,22 +135,31 @@ class GitLabProvider(GitProvider):
# allow only a limited number of files to be fully loaded. We can manage the rest with diffs only
counter_valid += 1
if counter_valid < MAX_FILES_ALLOWED_FULL or not diff['diff']:
original_file_content_str = self.get_pr_file_content(diff['old_path'], self.mr.diff_refs['base_sha'])
new_file_content_str = self.get_pr_file_content(diff['new_path'], self.mr.diff_refs['head_sha'])
original_file_content_str = self.get_pr_file_content(
diff['old_path'], self.mr.diff_refs['base_sha']
)
new_file_content_str = self.get_pr_file_content(
diff['new_path'], self.mr.diff_refs['head_sha']
)
else:
if counter_valid == MAX_FILES_ALLOWED_FULL:
get_logger().info(f"Too many files in PR, will avoid loading full content for rest of files")
get_logger().info(
f"Too many files in PR, will avoid loading full content for rest of files"
)
original_file_content_str = ''
new_file_content_str = ''
try:
if isinstance(original_file_content_str, bytes):
original_file_content_str = bytes.decode(original_file_content_str, 'utf-8')
original_file_content_str = bytes.decode(
original_file_content_str, 'utf-8'
)
if isinstance(new_file_content_str, bytes):
new_file_content_str = bytes.decode(new_file_content_str, 'utf-8')
except UnicodeDecodeError:
get_logger().warning(
f"Cannot decode file {diff['old_path']} or {diff['new_path']} in merge request {self.id_mr}")
f"Cannot decode file {diff['old_path']} or {diff['new_path']} in merge request {self.id_mr}"
)
edit_type = EDIT_TYPE.MODIFIED
if diff['new_file']:
@ -144,30 +172,43 @@ class GitLabProvider(GitProvider):
filename = diff['new_path']
patch = diff['diff']
if not patch:
patch = load_large_diff(filename, new_file_content_str, original_file_content_str)
patch = load_large_diff(
filename, new_file_content_str, original_file_content_str
)
# count number of lines added and removed
patch_lines = patch.splitlines(keepends=True)
num_plus_lines = len([line for line in patch_lines if line.startswith('+')])
num_minus_lines = len([line for line in patch_lines if line.startswith('-')])
num_minus_lines = len(
[line for line in patch_lines if line.startswith('-')]
)
diff_files.append(
FilePatchInfo(original_file_content_str, new_file_content_str,
patch=patch,
filename=filename,
edit_type=edit_type,
old_filename=None if diff['old_path'] == diff['new_path'] else diff['old_path'],
num_plus_lines=num_plus_lines,
num_minus_lines=num_minus_lines, ))
FilePatchInfo(
original_file_content_str,
new_file_content_str,
patch=patch,
filename=filename,
edit_type=edit_type,
old_filename=None
if diff['old_path'] == diff['new_path']
else diff['old_path'],
num_plus_lines=num_plus_lines,
num_minus_lines=num_minus_lines,
)
)
if invalid_files_names:
get_logger().info(f"Filtered out files with invalid extensions: {invalid_files_names}")
get_logger().info(
f"Filtered out files with invalid extensions: {invalid_files_names}"
)
self.diff_files = diff_files
return diff_files
def get_files(self) -> list:
if not self.git_files:
self.git_files = [change['new_path'] for change in self.mr.changes()['changes']]
self.git_files = [
change['new_path'] for change in self.mr.changes()['changes']
]
return self.git_files
def publish_description(self, pr_title: str, pr_body: str):
@ -176,7 +217,9 @@ class GitLabProvider(GitProvider):
self.mr.description = pr_body
self.mr.save()
except Exception as e:
get_logger().exception(f"Could not update merge request {self.id_mr} description: {e}")
get_logger().exception(
f"Could not update merge request {self.id_mr} description: {e}"
)
def get_latest_commit_url(self):
return self.mr.commits().next().web_url
@ -184,16 +227,23 @@ class GitLabProvider(GitProvider):
def get_comment_url(self, comment):
return f"{self.mr.web_url}#note_{comment.id}"
def publish_persistent_comment(self, pr_comment: str,
initial_header: str,
update_header: bool = True,
name='review',
final_update_message=True):
self.publish_persistent_comment_full(pr_comment, initial_header, update_header, name, final_update_message)
def publish_persistent_comment(
self,
pr_comment: str,
initial_header: str,
update_header: bool = True,
name='review',
final_update_message=True,
):
self.publish_persistent_comment_full(
pr_comment, initial_header, update_header, name, final_update_message
)
def publish_comment(self, mr_comment: str, is_temporary: bool = False):
if is_temporary and not get_settings().config.publish_output_progress:
get_logger().debug(f"Skipping publish_comment for temporary comment: {mr_comment}")
get_logger().debug(
f"Skipping publish_comment for temporary comment: {mr_comment}"
)
return None
mr_comment = self.limit_output_characters(mr_comment, self.max_comment_chars)
comment = self.mr.notes.create({'body': mr_comment})
@ -203,7 +253,7 @@ class GitLabProvider(GitProvider):
def edit_comment(self, comment, body: str):
body = self.limit_output_characters(body, self.max_comment_chars)
self.mr.notes.update(comment.id,{'body': body} )
self.mr.notes.update(comment.id, {'body': body})
def edit_comment_from_comment_id(self, comment_id: int, body: str):
body = self.limit_output_characters(body, self.max_comment_chars)
@ -216,39 +266,87 @@ class GitLabProvider(GitProvider):
discussion = self.mr.discussions.get(comment_id)
discussion.notes.create({'body': body})
def publish_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str, original_suggestion=None):
def publish_inline_comment(
self,
body: str,
relevant_file: str,
relevant_line_in_file: str,
original_suggestion=None,
):
body = self.limit_output_characters(body, self.max_comment_chars)
edit_type, found, source_line_no, target_file, target_line_no = self.search_line(relevant_file,
relevant_line_in_file)
self.send_inline_comment(body, edit_type, found, relevant_file, relevant_line_in_file, source_line_no,
target_file, target_line_no, original_suggestion)
(
edit_type,
found,
source_line_no,
target_file,
target_line_no,
) = self.search_line(relevant_file, relevant_line_in_file)
self.send_inline_comment(
body,
edit_type,
found,
relevant_file,
relevant_line_in_file,
source_line_no,
target_file,
target_line_no,
original_suggestion,
)
def create_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str, absolute_position: int = None):
raise NotImplementedError("Gitlab provider does not support creating inline comments yet")
def create_inline_comment(
self,
body: str,
relevant_file: str,
relevant_line_in_file: str,
absolute_position: int = None,
):
raise NotImplementedError(
"Gitlab provider does not support creating inline comments yet"
)
def create_inline_comments(self, comments: list[dict]):
raise NotImplementedError("Gitlab provider does not support publishing inline comments yet")
raise NotImplementedError(
"Gitlab provider does not support publishing inline comments yet"
)
def get_comment_body_from_comment_id(self, comment_id: int):
comment = self.mr.notes.get(comment_id).body
return comment
def send_inline_comment(self, body: str, edit_type: str, found: bool, relevant_file: str,
relevant_line_in_file: str,
source_line_no: int, target_file: str, target_line_no: int,
original_suggestion=None) -> None:
def send_inline_comment(
self,
body: str,
edit_type: str,
found: bool,
relevant_file: str,
relevant_line_in_file: str,
source_line_no: int,
target_file: str,
target_line_no: int,
original_suggestion=None,
) -> None:
if not found:
get_logger().info(f"Could not find position for {relevant_file} {relevant_line_in_file}")
get_logger().info(
f"Could not find position for {relevant_file} {relevant_line_in_file}"
)
else:
# in order to have exact sha's we have to find correct diff for this change
diff = self.get_relevant_diff(relevant_file, relevant_line_in_file)
if diff is None:
get_logger().error(f"Could not get diff for merge request {self.id_mr}")
raise DiffNotFoundError(f"Could not get diff for merge request {self.id_mr}")
pos_obj = {'position_type': 'text',
'new_path': target_file.filename,
'old_path': target_file.old_filename if target_file.old_filename else target_file.filename,
'base_sha': diff.base_commit_sha, 'start_sha': diff.start_commit_sha, 'head_sha': diff.head_commit_sha}
raise DiffNotFoundError(
f"Could not get diff for merge request {self.id_mr}"
)
pos_obj = {
'position_type': 'text',
'new_path': target_file.filename,
'old_path': target_file.old_filename
if target_file.old_filename
else target_file.filename,
'base_sha': diff.base_commit_sha,
'start_sha': diff.start_commit_sha,
'head_sha': diff.head_commit_sha,
}
if edit_type == 'deletion':
pos_obj['old_line'] = source_line_no - 1
elif edit_type == 'addition':
@ -256,15 +354,21 @@ class GitLabProvider(GitProvider):
else:
pos_obj['new_line'] = target_line_no - 1
pos_obj['old_line'] = source_line_no - 1
get_logger().debug(f"Creating comment in MR {self.id_mr} with body {body} and position {pos_obj}")
get_logger().debug(
f"Creating comment in MR {self.id_mr} with body {body} and position {pos_obj}"
)
try:
self.mr.discussions.create({'body': body, 'position': pos_obj})
except Exception as e:
try:
# fallback - create a general note on the file in the MR
if 'suggestion_orig_location' in original_suggestion:
line_start = original_suggestion['suggestion_orig_location']['start_line']
line_end = original_suggestion['suggestion_orig_location']['end_line']
line_start = original_suggestion['suggestion_orig_location'][
'start_line'
]
line_end = original_suggestion['suggestion_orig_location'][
'end_line'
]
old_code_snippet = original_suggestion['prev_code_snippet']
new_code_snippet = original_suggestion['new_code_snippet']
content = original_suggestion['suggestion_summary']
@ -287,36 +391,49 @@ class GitLabProvider(GitProvider):
else:
language = ''
link = self.get_line_link(relevant_file, line_start, line_end)
body_fallback =f"**Suggestion:** {content} [{label}, importance: {score}]\n\n"
body_fallback +=f"\n\n<details><summary>[{target_file.filename} [{line_start}-{line_end}]]({link}):</summary>\n\n"
body_fallback = (
f"**Suggestion:** {content} [{label}, importance: {score}]\n\n"
)
body_fallback += f"\n\n<details><summary>[{target_file.filename} [{line_start}-{line_end}]]({link}):</summary>\n\n"
body_fallback += f"\n\n___\n\n`(Cannot implement directly - GitLab API allows committable suggestions strictly on MR diff lines)`"
body_fallback+="</details>\n\n"
diff_patch = difflib.unified_diff(old_code_snippet.split('\n'),
new_code_snippet.split('\n'), n=999)
body_fallback += "</details>\n\n"
diff_patch = difflib.unified_diff(
old_code_snippet.split('\n'),
new_code_snippet.split('\n'),
n=999,
)
patch_orig = "\n".join(diff_patch)
patch = "\n".join(patch_orig.splitlines()[5:]).strip('\n')
diff_code = f"\n\n```diff\n{patch.rstrip()}\n```"
body_fallback += diff_code
# Create a general note on the file in the MR
self.mr.notes.create({
'body': body_fallback,
'position': {
'base_sha': diff.base_commit_sha,
'start_sha': diff.start_commit_sha,
'head_sha': diff.head_commit_sha,
'position_type': 'text',
'file_path': f'{target_file.filename}',
self.mr.notes.create(
{
'body': body_fallback,
'position': {
'base_sha': diff.base_commit_sha,
'start_sha': diff.start_commit_sha,
'head_sha': diff.head_commit_sha,
'position_type': 'text',
'file_path': f'{target_file.filename}',
},
}
})
get_logger().debug(f"Created fallback comment in MR {self.id_mr} with position {pos_obj}")
)
get_logger().debug(
f"Created fallback comment in MR {self.id_mr} with position {pos_obj}"
)
# get_logger().debug(
# f"Failed to create comment in MR {self.id_mr} with position {pos_obj} (probably not a '+' line)")
except Exception as e:
get_logger().exception(f"Failed to create comment in MR {self.id_mr}")
get_logger().exception(
f"Failed to create comment in MR {self.id_mr}"
)
def get_relevant_diff(self, relevant_file: str, relevant_line_in_file: str) -> Optional[dict]:
def get_relevant_diff(
self, relevant_file: str, relevant_line_in_file: str
) -> Optional[dict]:
changes = self.mr.changes() # Retrieve the changes for the merge request once
if not changes:
get_logger().error('No changes found for the merge request.')
@ -327,10 +444,14 @@ class GitLabProvider(GitProvider):
return None
for diff in all_diffs:
for change in changes['changes']:
if change['new_path'] == relevant_file and relevant_line_in_file in change['diff']:
if (
change['new_path'] == relevant_file
and relevant_line_in_file in change['diff']
):
return diff
get_logger().debug(
f'No relevant diff found for {relevant_file} {relevant_line_in_file}. Falling back to last diff.')
f'No relevant diff found for {relevant_file} {relevant_line_in_file}. Falling back to last diff.'
)
return self.last_diff # fallback to last_diff if no relevant diff is found
def publish_code_suggestions(self, code_suggestions: list) -> bool:
@ -352,7 +473,7 @@ class GitLabProvider(GitProvider):
if file.filename == relevant_file:
target_file = file
break
range = relevant_lines_end - relevant_lines_start # no need to add 1
range = relevant_lines_end - relevant_lines_start # no need to add 1
body = body.replace('```suggestion', f'```suggestion:-0+{range}')
lines = target_file.head_file.splitlines()
relevant_line_in_file = lines[relevant_lines_start - 1]
@ -365,10 +486,21 @@ class GitLabProvider(GitProvider):
found = True
edit_type = 'addition'
self.send_inline_comment(body, edit_type, found, relevant_file, relevant_line_in_file, source_line_no,
target_file, target_line_no, original_suggestion)
self.send_inline_comment(
body,
edit_type,
found,
relevant_file,
relevant_line_in_file,
source_line_no,
target_file,
target_line_no,
original_suggestion,
)
except Exception as e:
get_logger().exception(f"Could not publish code suggestion:\nsuggestion: {suggestion}\nerror: {e}")
get_logger().exception(
f"Could not publish code suggestion:\nsuggestion: {suggestion}\nerror: {e}"
)
# note that we publish suggestions one-by-one. so, if one fails, the rest will still be published
return True
@ -382,8 +514,13 @@ class GitLabProvider(GitProvider):
edit_type = self.get_edit_type(relevant_line_in_file)
for file in self.get_diff_files():
if file.filename == relevant_file:
edit_type, found, source_line_no, target_file, target_line_no = self.find_in_file(file,
relevant_line_in_file)
(
edit_type,
found,
source_line_no,
target_file,
target_line_no,
) = self.find_in_file(file, relevant_line_in_file)
return edit_type, found, source_line_no, target_file, target_line_no
def find_in_file(self, file, relevant_line_in_file):
@ -414,7 +551,10 @@ class GitLabProvider(GitProvider):
found = True
edit_type = self.get_edit_type(line)
break
elif relevant_line_in_file[0] == '+' and relevant_line_in_file[1:].lstrip() in line:
elif (
relevant_line_in_file[0] == '+'
and relevant_line_in_file[1:].lstrip() in line
):
# The model often adds a '+' to the beginning of the relevant_line_in_file even if originally
# it's a context line
found = True
@ -470,7 +610,11 @@ class GitLabProvider(GitProvider):
def get_repo_settings(self):
try:
contents = self.gl.projects.get(self.id_project).files.get(file_path='.pr_agent.toml', ref=self.mr.target_branch).decode()
contents = (
self.gl.projects.get(self.id_project)
.files.get(file_path='.pr_agent.toml', ref=self.mr.target_branch)
.decode()
)
return contents
except Exception:
return ""
@ -478,7 +622,9 @@ class GitLabProvider(GitProvider):
def get_workspace_name(self):
return self.id_project.split('/')[0]
def add_eyes_reaction(self, issue_comment_id: int, disable_eyes: bool = False) -> Optional[int]:
def add_eyes_reaction(
self, issue_comment_id: int, disable_eyes: bool = False
) -> Optional[int]:
return True
def remove_reaction(self, issue_comment_id: int, reaction_id: int) -> bool:
@ -489,7 +635,9 @@ class GitLabProvider(GitProvider):
path_parts = parsed_url.path.strip('/').split('/')
if 'merge_requests' not in path_parts:
raise ValueError("The provided URL does not appear to be a GitLab merge request URL")
raise ValueError(
"The provided URL does not appear to be a GitLab merge request URL"
)
mr_index = path_parts.index('merge_requests')
# Ensure there is an ID after 'merge_requests'
@ -541,8 +689,15 @@ class GitLabProvider(GitProvider):
"""
max_tokens = get_settings().get("CONFIG.MAX_COMMITS_TOKENS", None)
try:
commit_messages_list = [commit['message'] for commit in self.mr.commits()._list]
commit_messages_str = "\n".join([f"{i + 1}. {message}" for i, message in enumerate(commit_messages_list)])
commit_messages_list = [
commit['message'] for commit in self.mr.commits()._list
]
commit_messages_str = "\n".join(
[
f"{i + 1}. {message}"
for i, message in enumerate(commit_messages_list)
]
)
except Exception:
commit_messages_str = ""
if max_tokens:
@ -556,7 +711,12 @@ class GitLabProvider(GitProvider):
except:
return ""
def get_line_link(self, relevant_file: str, relevant_line_start: int, relevant_line_end: int = None) -> str:
def get_line_link(
self,
relevant_file: str,
relevant_line_start: int,
relevant_line_end: int = None,
) -> str:
if relevant_line_start == -1:
link = f"{self.gl.url}/{self.id_project}/-/blob/{self.mr.source_branch}/{relevant_file}?ref_type=heads"
elif relevant_line_end:
@ -565,7 +725,6 @@ class GitLabProvider(GitProvider):
link = f"{self.gl.url}/{self.id_project}/-/blob/{self.mr.source_branch}/{relevant_file}?ref_type=heads#L{relevant_line_start}"
return link
def generate_link_to_relevant_line_number(self, suggestion) -> str:
try:
relevant_file = suggestion['relevant_file'].strip('`').strip("'").rstrip()
@ -573,8 +732,9 @@ class GitLabProvider(GitProvider):
if not relevant_line_str:
return ""
position, absolute_position = find_line_number_of_relevant_line_in_file \
(self.diff_files, relevant_file, relevant_line_str)
position, absolute_position = find_line_number_of_relevant_line_in_file(
self.diff_files, relevant_file, relevant_line_str
)
if absolute_position != -1:
# link to right file only
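Aside: a condensed sketch of the position payload this provider builds for inline MR discussions with python-gitlab; the URL, token, project path, MR id, file path and line number below are placeholders, not values from this repository:

import gitlab

gl = gitlab.Gitlab(url="https://gitlab.example.com", oauth_token="<personal-access-token>")  # placeholders
mr = gl.projects.get("group/repo").mergerequests.get(42)
diff = mr.diffs.list(get_all=True)[-1]  # latest diff version, the same fallback used above

position = {
    'position_type': 'text',
    'new_path': 'src/app.py',  # placeholder file in the MR
    'old_path': 'src/app.py',
    'base_sha': diff.base_commit_sha,
    'start_sha': diff.start_commit_sha,
    'head_sha': diff.head_commit_sha,
    'new_line': 10,  # use 'old_line' for deletions, or both for context lines, as above
}
mr.discussions.create({'body': "Inline note", 'position': position})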

View File

@ -39,10 +39,16 @@ class LocalGitProvider(GitProvider):
self._prepare_repo()
self.diff_files = None
self.pr = PullRequestMimic(self.get_pr_title(), self.get_diff_files())
self.description_path = get_settings().get('local.description_path') \
if get_settings().get('local.description_path') is not None else self.repo_path / 'description.md'
self.review_path = get_settings().get('local.review_path') \
if get_settings().get('local.review_path') is not None else self.repo_path / 'review.md'
self.description_path = (
get_settings().get('local.description_path')
if get_settings().get('local.description_path') is not None
else self.repo_path / 'description.md'
)
self.review_path = (
get_settings().get('local.review_path')
if get_settings().get('local.review_path') is not None
else self.repo_path / 'review.md'
)
# inline code comments are not supported for local git repositories
get_settings().pr_reviewer.inline_code_comments = False
@ -52,30 +58,43 @@ class LocalGitProvider(GitProvider):
"""
get_logger().debug('Preparing repository for PR-mimic generation...')
if self.repo.is_dirty():
raise ValueError('The repository is not in a clean state. Please commit or stash pending changes.')
raise ValueError(
'The repository is not in a clean state. Please commit or stash pending changes.'
)
if self.target_branch_name not in self.repo.heads:
raise KeyError(f'Branch: {self.target_branch_name} does not exist')
def is_supported(self, capability: str) -> bool:
if capability in ['get_issue_comments', 'create_inline_comment', 'publish_inline_comments', 'get_labels',
'gfm_markdown']:
if capability in [
'get_issue_comments',
'create_inline_comment',
'publish_inline_comments',
'get_labels',
'gfm_markdown',
]:
return False
return True
def get_diff_files(self) -> list[FilePatchInfo]:
diffs = self.repo.head.commit.diff(
self.repo.merge_base(self.repo.head, self.repo.branches[self.target_branch_name]),
self.repo.merge_base(
self.repo.head, self.repo.branches[self.target_branch_name]
),
create_patch=True,
R=True
R=True,
)
diff_files = []
for diff_item in diffs:
if diff_item.a_blob is not None:
original_file_content_str = diff_item.a_blob.data_stream.read().decode('utf-8')
original_file_content_str = diff_item.a_blob.data_stream.read().decode(
'utf-8'
)
else:
original_file_content_str = "" # empty file
if diff_item.b_blob is not None:
new_file_content_str = diff_item.b_blob.data_stream.read().decode('utf-8')
new_file_content_str = diff_item.b_blob.data_stream.read().decode(
'utf-8'
)
else:
new_file_content_str = "" # empty file
edit_type = EDIT_TYPE.MODIFIED
@ -86,13 +105,16 @@ class LocalGitProvider(GitProvider):
elif diff_item.renamed_file:
edit_type = EDIT_TYPE.RENAMED
diff_files.append(
FilePatchInfo(original_file_content_str,
new_file_content_str,
diff_item.diff.decode('utf-8'),
diff_item.b_path,
edit_type=edit_type,
old_filename=None if diff_item.a_path == diff_item.b_path else diff_item.a_path
)
FilePatchInfo(
original_file_content_str,
new_file_content_str,
diff_item.diff.decode('utf-8'),
diff_item.b_path,
edit_type=edit_type,
old_filename=None
if diff_item.a_path == diff_item.b_path
else diff_item.a_path,
)
)
self.diff_files = diff_files
return diff_files
@ -102,8 +124,10 @@ class LocalGitProvider(GitProvider):
Returns a list of files with changes in the diff.
"""
diff_index = self.repo.head.commit.diff(
self.repo.merge_base(self.repo.head, self.repo.branches[self.target_branch_name]),
R=True
self.repo.merge_base(
self.repo.head, self.repo.branches[self.target_branch_name]
),
R=True,
)
# Get the list of changed files
diff_files = [item.a_path for item in diff_index]
@ -119,18 +143,37 @@ class LocalGitProvider(GitProvider):
# Write the string to the file
file.write(pr_comment)
def publish_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str, original_suggestion=None):
raise NotImplementedError('Publishing inline comments is not implemented for the local git provider')
def publish_inline_comment(
self,
body: str,
relevant_file: str,
relevant_line_in_file: str,
original_suggestion=None,
):
raise NotImplementedError(
'Publishing inline comments is not implemented for the local git provider'
)
def publish_inline_comments(self, comments: list[dict]):
raise NotImplementedError('Publishing inline comments is not implemented for the local git provider')
raise NotImplementedError(
'Publishing inline comments is not implemented for the local git provider'
)
def publish_code_suggestion(self, body: str, relevant_file: str,
relevant_lines_start: int, relevant_lines_end: int):
raise NotImplementedError('Publishing code suggestions is not implemented for the local git provider')
def publish_code_suggestion(
self,
body: str,
relevant_file: str,
relevant_lines_start: int,
relevant_lines_end: int,
):
raise NotImplementedError(
'Publishing code suggestions is not implemented for the local git provider'
)
def publish_code_suggestions(self, code_suggestions: list) -> bool:
raise NotImplementedError('Publishing code suggestions is not implemented for the local git provider')
raise NotImplementedError(
'Publishing code suggestions is not implemented for the local git provider'
)
def publish_labels(self, labels):
pass # Not applicable to the local git provider, but required by the interface
@ -158,19 +201,31 @@ class LocalGitProvider(GitProvider):
Calculate percentage of languages in repository. Used for hunk prioritisation.
"""
# Get all files in repository
filepaths = [Path(item.path) for item in self.repo.tree().traverse() if item.type == 'blob']
filepaths = [
Path(item.path)
for item in self.repo.tree().traverse()
if item.type == 'blob'
]
# Identify language by file extension and count
lang_count = Counter(ext.lstrip('.') for filepath in filepaths for ext in [filepath.suffix.lower()])
lang_count = Counter(
ext.lstrip('.')
for filepath in filepaths
for ext in [filepath.suffix.lower()]
)
# Convert counts to percentages
total_files = len(filepaths)
lang_percentage = {lang: count / total_files * 100 for lang, count in lang_count.items()}
lang_percentage = {
lang: count / total_files * 100 for lang, count in lang_count.items()
}
return lang_percentage
def get_pr_branch(self):
return self.repo.head
def get_user_id(self):
return -1 # Not used anywhere for the local provider, but required by the interface
return (
-1
) # Not used anywhere for the local provider, but required by the interface
def get_pr_description_full(self):
commits_diff = list(self.repo.iter_commits(self.target_branch_name + '..HEAD'))
@ -186,7 +241,11 @@ class LocalGitProvider(GitProvider):
return self.head_branch_name
def get_issue_comments(self):
raise NotImplementedError('Getting issue comments is not implemented for the local git provider')
raise NotImplementedError(
'Getting issue comments is not implemented for the local git provider'
)
def get_pr_labels(self, update=False):
raise NotImplementedError('Getting labels is not implemented for the local git provider')
raise NotImplementedError(
'Getting labels is not implemented for the local git provider'
)
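Aside: a small self-contained sketch of the language-percentage calculation above, using GitPython; the repository path is a placeholder, and the empty-tree guard is added here for safety and is not in the original:

from collections import Counter
from pathlib import Path

from git import Repo

repo = Repo("/path/to/local/repo")  # placeholder
filepaths = [Path(item.path) for item in repo.tree().traverse() if item.type == 'blob']
# Count files per extension, e.g. {'py': 12, 'md': 3}
lang_count = Counter(path.suffix.lower().lstrip('.') for path in filepaths)
total_files = len(filepaths) or 1  # guard against an empty tree (added for the sketch)
lang_percentage = {lang: count / total_files * 100 for lang, count in lang_count.items()}
print(lang_percentage)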

View File

@ -6,7 +6,7 @@ from dynaconf import Dynaconf
from starlette_context import context
from utils.pr_agent.config_loader import get_settings
from utils.pr_agent.git_providers import (get_git_provider_with_context)
from utils.pr_agent.git_providers import get_git_provider_with_context
from utils.pr_agent.log import get_logger
@ -20,7 +20,9 @@ def apply_repo_settings(pr_url):
except Exception:
repo_settings = None
pass
if repo_settings is None: # None is different from "", which is a valid value
if (
repo_settings is None
): # None is different from "", which is a valid value
repo_settings = git_provider.get_repo_settings()
try:
context["repo_settings"] = repo_settings
@ -36,15 +38,25 @@ def apply_repo_settings(pr_url):
os.write(fd, repo_settings)
new_settings = Dynaconf(settings_files=[repo_settings_file])
for section, contents in new_settings.as_dict().items():
section_dict = copy.deepcopy(get_settings().as_dict().get(section, {}))
section_dict = copy.deepcopy(
get_settings().as_dict().get(section, {})
)
for key, value in contents.items():
section_dict[key] = value
get_settings().unset(section)
get_settings().set(section, section_dict, merge=False)
get_logger().info(f"Applying repo settings:\n{new_settings.as_dict()}")
get_logger().info(
f"Applying repo settings:\n{new_settings.as_dict()}"
)
except Exception as e:
get_logger().warning(f"Failed to apply repo {category} settings, error: {str(e)}")
error_local = {'error': str(e), 'settings': repo_settings, 'category': category}
get_logger().warning(
f"Failed to apply repo {category} settings, error: {str(e)}"
)
error_local = {
'error': str(e),
'settings': repo_settings,
'category': category,
}
if error_local:
handle_configurations_errors([error_local], git_provider)
@ -55,7 +67,10 @@ def apply_repo_settings(pr_url):
try:
os.remove(repo_settings_file)
except Exception as e:
get_logger().error(f"Failed to remove temporary settings file {repo_settings_file}", e)
get_logger().error(
f"Failed to remove temporary settings file {repo_settings_file}",
e,
)
# enable switching models with a short definition
if get_settings().config.model.lower() == 'claude-3-5-sonnet':
@ -79,13 +94,18 @@ def handle_configurations_errors(config_errors, git_provider):
body += f"\n\n<details><summary>配置内容:</summary>\n\n```toml\n{configuration_file_content}\n```\n\n</details>"
else:
body += f"\n\n**配置内容:**\n\n```toml\n{configuration_file_content}\n```\n\n"
get_logger().warning(f"Sending a 'configuration error' comment to the PR", artifact={'body': body})
get_logger().warning(
f"Sending a 'configuration error' comment to the PR",
artifact={'body': body},
)
# git_provider.publish_comment(body)
if hasattr(git_provider, 'publish_persistent_comment'):
git_provider.publish_persistent_comment(body,
initial_header=header,
update_header=False,
final_update_message=False)
git_provider.publish_persistent_comment(
body,
initial_header=header,
update_header=False,
final_update_message=False,
)
else:
git_provider.publish_comment(body)
except Exception as e:
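Aside: a hedged sketch of the per-section settings merge that apply_repo_settings performs above, where a repo-local TOML file is loaded with Dynaconf and layered over the global settings; the settings accessor is passed in to keep the sketch self-contained, and the file path is a placeholder:

import copy

from dynaconf import Dynaconf

def merge_repo_settings(get_settings, repo_settings_file: str) -> None:
    # get_settings() is assumed to return the project's Dynaconf settings object.
    new_settings = Dynaconf(settings_files=[repo_settings_file])
    for section, contents in new_settings.as_dict().items():
        section_dict = copy.deepcopy(get_settings().as_dict().get(section, {}))
        for key, value in contents.items():
            section_dict[key] = value  # repo-level values override global ones
        get_settings().unset(section)
        get_settings().set(section, section_dict, merge=False)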

View File

@ -1,10 +1,9 @@
from utils.pr_agent.config_loader import get_settings
from utils.pr_agent.identity_providers.default_identity_provider import \
DefaultIdentityProvider
from utils.pr_agent.identity_providers.default_identity_provider import (
DefaultIdentityProvider,
)
_IDENTITY_PROVIDERS = {
'default': DefaultIdentityProvider
}
_IDENTITY_PROVIDERS = {'default': DefaultIdentityProvider}
def get_identity_provider():

View File

@ -1,5 +1,7 @@
from utils.pr_agent.identity_providers.identity_provider import (Eligibility,
IdentityProvider)
from utils.pr_agent.identity_providers.identity_provider import (
Eligibility,
IdentityProvider,
)
class DefaultIdentityProvider(IdentityProvider):

View File

@ -30,7 +30,9 @@ def setup_logger(level: str = "INFO", fmt: LoggingFormat = LoggingFormat.CONSOLE
if type(level) is not int:
level = logging.INFO
if fmt == LoggingFormat.JSON and os.getenv("LOG_SANE", "0").lower() == "0": # better debugging github_app
if (
fmt == LoggingFormat.JSON and os.getenv("LOG_SANE", "0").lower() == "0"
): # better debugging github_app
logger.remove(None)
logger.add(
sys.stdout,
@ -40,7 +42,7 @@ def setup_logger(level: str = "INFO", fmt: LoggingFormat = LoggingFormat.CONSOLE
colorize=False,
serialize=True,
)
elif fmt == LoggingFormat.CONSOLE: # does not print the 'extra' fields
elif fmt == LoggingFormat.CONSOLE: # does not print the 'extra' fields
logger.remove(None)
logger.add(sys.stdout, level=level, colorize=True, filter=inv_analytics_filter)

View File

@ -8,10 +8,14 @@ def get_secret_provider():
provider_id = get_settings().config.secret_provider
if provider_id == 'google_cloud_storage':
try:
from utils.pr_agent.secret_providers.google_cloud_storage_secret_provider import \
GoogleCloudStorageSecretProvider
from utils.pr_agent.secret_providers.google_cloud_storage_secret_provider import (
GoogleCloudStorageSecretProvider,
)
return GoogleCloudStorageSecretProvider()
except Exception as e:
raise ValueError(f"Failed to initialize google_cloud_storage secret provider {provider_id}") from e
raise ValueError(
f"Failed to initialize google_cloud_storage secret provider {provider_id}"
) from e
else:
raise ValueError("Unknown SECRET_PROVIDER")

View File

@ -9,12 +9,15 @@ from utils.pr_agent.secret_providers.secret_provider import SecretProvider
class GoogleCloudStorageSecretProvider(SecretProvider):
def __init__(self):
try:
self.client = storage.Client.from_service_account_info(ujson.loads(get_settings().google_cloud_storage.
service_account))
self.client = storage.Client.from_service_account_info(
ujson.loads(get_settings().google_cloud_storage.service_account)
)
self.bucket_name = get_settings().google_cloud_storage.bucket_name
self.bucket = self.client.bucket(self.bucket_name)
except Exception as e:
get_logger().error(f"Failed to initialize Google Cloud Storage Secret Provider: {e}")
get_logger().error(
f"Failed to initialize Google Cloud Storage Secret Provider: {e}"
)
raise e
def get_secret(self, secret_name: str) -> str:
@ -22,7 +25,9 @@ class GoogleCloudStorageSecretProvider(SecretProvider):
blob = self.bucket.blob(secret_name)
return blob.download_as_string()
except Exception as e:
get_logger().warning(f"Failed to get secret {secret_name} from Google Cloud Storage: {e}")
get_logger().warning(
f"Failed to get secret {secret_name} from Google Cloud Storage: {e}"
)
return ""
def store_secret(self, secret_name: str, secret_value: str):
@ -30,5 +35,7 @@ class GoogleCloudStorageSecretProvider(SecretProvider):
blob = self.bucket.blob(secret_name)
blob.upload_from_string(secret_value)
except Exception as e:
get_logger().error(f"Failed to store secret {secret_name} in Google Cloud Storage: {e}")
get_logger().error(
f"Failed to store secret {secret_name} in Google Cloud Storage: {e}"
)
raise e
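Aside: a brief usage sketch of the storage calls wrapped above; the bucket name, blob name and service-account JSON are placeholders:

import ujson
from google.cloud import storage

service_account_json = '{"type": "service_account", "project_id": "my-project"}'  # placeholder credentials
client = storage.Client.from_service_account_info(ujson.loads(service_account_json))
bucket = client.bucket("my-secrets-bucket")

bucket.blob("client-key").upload_from_string('{"shared_secret": "..."}')  # store_secret path
secret = bucket.blob("client-key").download_as_string()                   # get_secret path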

View File

@ -2,7 +2,6 @@ from abc import ABC, abstractmethod
class SecretProvider(ABC):
@abstractmethod
def get_secret(self, secret_name: str) -> str:
pass

View File

@ -33,6 +33,7 @@ azure_devops_server = get_settings().get("azure_devops_server")
WEBHOOK_USERNAME = azure_devops_server.get("webhook_username")
WEBHOOK_PASSWORD = azure_devops_server.get("webhook_password")
def handle_request(
background_tasks: BackgroundTasks, url: str, body: str, log_context: dict
):
@ -52,20 +53,27 @@ def handle_request(
# currently only basic auth is supported with azure webhooks
# for this reason, https must be enabled to ensure the credentials are not sent in clear text
def authorize(credentials: HTTPBasicCredentials = Depends(security)):
is_user_ok = secrets.compare_digest(credentials.username, WEBHOOK_USERNAME)
is_pass_ok = secrets.compare_digest(credentials.password, WEBHOOK_PASSWORD)
if not (is_user_ok and is_pass_ok):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail='Incorrect username or password.',
headers={'WWW-Authenticate': 'Basic'},
)
is_user_ok = secrets.compare_digest(credentials.username, WEBHOOK_USERNAME)
is_pass_ok = secrets.compare_digest(credentials.password, WEBHOOK_PASSWORD)
if not (is_user_ok and is_pass_ok):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail='Incorrect username or password.',
headers={'WWW-Authenticate': 'Basic'},
)
async def _perform_commands_azure(commands_conf: str, agent: PRAgent, api_url: str, log_context: dict):
async def _perform_commands_azure(
commands_conf: str, agent: PRAgent, api_url: str, log_context: dict
):
apply_repo_settings(api_url)
if commands_conf == "pr_commands" and get_settings().config.disable_auto_feedback: # auto commands for PR, and auto feedback is disabled
get_logger().info(f"Auto feedback is disabled, skipping auto commands for PR {api_url=}", **log_context)
if (
commands_conf == "pr_commands" and get_settings().config.disable_auto_feedback
): # auto commands for PR, and auto feedback is disabled
get_logger().info(
f"Auto feedback is disabled, skipping auto commands for PR {api_url=}",
**log_context,
)
return
commands = get_settings().get(f"azure_devops_server.{commands_conf}")
get_settings().set("config.is_auto_command", True)
@ -92,22 +100,38 @@ async def handle_webhook(background_tasks: BackgroundTasks, request: Request):
actions = []
if data["eventType"] == "git.pullrequest.created":
# API V1 (latest)
pr_url = unquote(data["resource"]["_links"]["web"]["href"].replace("_apis/git/repositories", "_git"))
pr_url = unquote(
data["resource"]["_links"]["web"]["href"].replace(
"_apis/git/repositories", "_git"
)
)
log_context["event"] = data["eventType"]
log_context["api_url"] = pr_url
await _perform_commands_azure("pr_commands", PRAgent(), pr_url, log_context)
return
elif data["eventType"] == "ms.vss-code.git-pullrequest-comment-event" and "content" in data["resource"]["comment"]:
elif (
data["eventType"] == "ms.vss-code.git-pullrequest-comment-event"
and "content" in data["resource"]["comment"]
):
if available_commands_rgx.match(data["resource"]["comment"]["content"]):
if(data["resourceVersion"] == "2.0"):
if data["resourceVersion"] == "2.0":
repo = data["resource"]["pullRequest"]["repository"]["webUrl"]
pr_url = unquote(f'{repo}/pullrequest/{data["resource"]["pullRequest"]["pullRequestId"]}')
pr_url = unquote(
f'{repo}/pullrequest/{data["resource"]["pullRequest"]["pullRequestId"]}'
)
actions = [data["resource"]["comment"]["content"]]
else:
# API V1 not supported as it does not contain the PR URL
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content=json.dumps({"message": "version 1.0 webhook for Azure Devops PR comment is not supported. please upgrade to version 2.0"})),
return (
JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content=json.dumps(
{
"message": "version 1.0 webhook for Azure Devops PR comment is not supported. please upgrade to version 2.0"
}
),
),
)
else:
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
@ -132,17 +156,21 @@ async def handle_webhook(background_tasks: BackgroundTasks, request: Request):
content=json.dumps({"message": "Internal server error"}),
)
return JSONResponse(
status_code=status.HTTP_202_ACCEPTED, content=jsonable_encoder({"message": "webhook triggered successfully"})
status_code=status.HTTP_202_ACCEPTED,
content=jsonable_encoder({"message": "webhook triggered successfully"}),
)
@router.get("/")
async def root():
return {"status": "ok"}
def start():
app = FastAPI(middleware=[Middleware(RawContextMiddleware)])
app.include_router(router)
uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "3000")))
if __name__ == "__main__":
start()
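Aside: a minimal sketch of the Basic-auth guard used by this webhook, with constant-time credential comparison via secrets.compare_digest; the username/password values and the route are placeholders:

import secrets

from fastapi import Depends, FastAPI, HTTPException, status
from fastapi.security import HTTPBasic, HTTPBasicCredentials

app = FastAPI()
security = HTTPBasic()
WEBHOOK_USERNAME = "user"      # placeholder
WEBHOOK_PASSWORD = "password"  # placeholder

def authorize(credentials: HTTPBasicCredentials = Depends(security)):
    # compare_digest avoids timing side channels on credential checks
    is_user_ok = secrets.compare_digest(credentials.username, WEBHOOK_USERNAME)
    is_pass_ok = secrets.compare_digest(credentials.password, WEBHOOK_PASSWORD)
    if not (is_user_ok and is_pass_ok):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail='Incorrect username or password.',
            headers={'WWW-Authenticate': 'Basic'},
        )

@app.post("/webhook", dependencies=[Depends(authorize)])
async def webhook():
    return {"status": "ok"}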

View File

@ -27,7 +27,9 @@ from utils.pr_agent.secret_providers import get_secret_provider
setup_logger(fmt=LoggingFormat.JSON, level="DEBUG")
router = APIRouter()
secret_provider = get_secret_provider() if get_settings().get("CONFIG.SECRET_PROVIDER") else None
secret_provider = (
get_secret_provider() if get_settings().get("CONFIG.SECRET_PROVIDER") else None
)
async def get_bearer_token(shared_secret: str, client_key: str):
@ -44,12 +46,12 @@ async def get_bearer_token(shared_secret: str, client_key: str):
"exp": now + 240,
"qsh": qsh,
"sub": client_key,
}
}
token = jwt.encode(payload, shared_secret, algorithm="HS256")
payload = 'grant_type=urn%3Abitbucket%3Aoauth2%3Ajwt'
headers = {
'Authorization': f'JWT {token}',
'Content-Type': 'application/x-www-form-urlencoded'
'Content-Type': 'application/x-www-form-urlencoded',
}
response = requests.request("POST", url, headers=headers, data=payload)
bearer_token = response.json()["access_token"]
@ -58,6 +60,7 @@ async def get_bearer_token(shared_secret: str, client_key: str):
get_logger().error(f"Failed to get bearer token: {e}")
raise e
@router.get("/")
async def handle_manifest(request: Request, response: Response):
cur_dir = os.path.dirname(os.path.abspath(__file__))
@ -66,7 +69,9 @@ async def handle_manifest(request: Request, response: Response):
manifest = manifest.replace("app_key", get_settings().bitbucket.app_key)
manifest = manifest.replace("base_url", get_settings().bitbucket.base_url)
except:
get_logger().error("Failed to replace api_key in Bitbucket manifest, trying to continue")
get_logger().error(
"Failed to replace api_key in Bitbucket manifest, trying to continue"
)
manifest_obj = json.loads(manifest)
return JSONResponse(manifest_obj)
@ -83,10 +88,16 @@ def _get_username(data):
return ""
async def _perform_commands_bitbucket(commands_conf: str, agent: PRAgent, api_url: str, log_context: dict, data: dict):
async def _perform_commands_bitbucket(
commands_conf: str, agent: PRAgent, api_url: str, log_context: dict, data: dict
):
apply_repo_settings(api_url)
if commands_conf == "pr_commands" and get_settings().config.disable_auto_feedback: # auto commands for PR, and auto feedback is disabled
get_logger().info(f"Auto feedback is disabled, skipping auto commands for PR {api_url=}")
if (
commands_conf == "pr_commands" and get_settings().config.disable_auto_feedback
): # auto commands for PR, and auto feedback is disabled
get_logger().info(
f"Auto feedback is disabled, skipping auto commands for PR {api_url=}"
)
return
if data.get("event", "") == "pullrequest:created":
if not should_process_pr_logic(data):
@ -132,7 +143,9 @@ def should_process_pr_logic(data) -> bool:
ignore_pr_users = get_settings().get("CONFIG.IGNORE_PR_AUTHORS", [])
if ignore_pr_users and sender:
if sender in ignore_pr_users:
get_logger().info(f"Ignoring PR from user '{sender}' due to 'config.ignore_pr_authors' setting")
get_logger().info(
f"Ignoring PR from user '{sender}' due to 'config.ignore_pr_authors' setting"
)
return False
# logic to ignore PRs with specific titles
@ -140,20 +153,34 @@ def should_process_pr_logic(data) -> bool:
ignore_pr_title_re = get_settings().get("CONFIG.IGNORE_PR_TITLE", [])
if not isinstance(ignore_pr_title_re, list):
ignore_pr_title_re = [ignore_pr_title_re]
if ignore_pr_title_re and any(re.search(regex, title) for regex in ignore_pr_title_re):
get_logger().info(f"Ignoring PR with title '{title}' due to config.ignore_pr_title setting")
if ignore_pr_title_re and any(
re.search(regex, title) for regex in ignore_pr_title_re
):
get_logger().info(
f"Ignoring PR with title '{title}' due to config.ignore_pr_title setting"
)
return False
ignore_pr_source_branches = get_settings().get("CONFIG.IGNORE_PR_SOURCE_BRANCHES", [])
ignore_pr_target_branches = get_settings().get("CONFIG.IGNORE_PR_TARGET_BRANCHES", [])
if (ignore_pr_source_branches or ignore_pr_target_branches):
if any(re.search(regex, source_branch) for regex in ignore_pr_source_branches):
ignore_pr_source_branches = get_settings().get(
"CONFIG.IGNORE_PR_SOURCE_BRANCHES", []
)
ignore_pr_target_branches = get_settings().get(
"CONFIG.IGNORE_PR_TARGET_BRANCHES", []
)
if ignore_pr_source_branches or ignore_pr_target_branches:
if any(
re.search(regex, source_branch) for regex in ignore_pr_source_branches
):
get_logger().info(
f"Ignoring PR with source branch '{source_branch}' due to config.ignore_pr_source_branches settings")
f"Ignoring PR with source branch '{source_branch}' due to config.ignore_pr_source_branches settings"
)
return False
if any(re.search(regex, target_branch) for regex in ignore_pr_target_branches):
if any(
re.search(regex, target_branch) for regex in ignore_pr_target_branches
):
get_logger().info(
f"Ignoring PR with target branch '{target_branch}' due to config.ignore_pr_target_branches settings")
f"Ignoring PR with target branch '{target_branch}' due to config.ignore_pr_target_branches settings"
)
return False
except Exception as e:
get_logger().error(f"Failed 'should_process_pr_logic': {e}")
@ -195,7 +222,9 @@ async def handle_github_webhooks(background_tasks: BackgroundTasks, request: Req
client_key = claims["iss"]
secrets = json.loads(secret_provider.get_secret(client_key))
shared_secret = secrets["shared_secret"]
jwt.decode(input_jwt, shared_secret, audience=client_key, algorithms=["HS256"])
jwt.decode(
input_jwt, shared_secret, audience=client_key, algorithms=["HS256"]
)
bearer_token = await get_bearer_token(shared_secret, client_key)
context['bitbucket_bearer_token'] = bearer_token
context["settings"] = copy.deepcopy(global_settings)
@ -208,28 +237,41 @@ async def handle_github_webhooks(background_tasks: BackgroundTasks, request: Req
if pr_url:
with get_logger().contextualize(**log_context):
apply_repo_settings(pr_url)
if get_identity_provider().verify_eligibility("bitbucket",
sender_id, pr_url) is not Eligibility.NOT_ELIGIBLE:
if (
get_identity_provider().verify_eligibility(
"bitbucket", sender_id, pr_url
)
is not Eligibility.NOT_ELIGIBLE
):
if get_settings().get("bitbucket_app.pr_commands"):
await _perform_commands_bitbucket("pr_commands", PRAgent(), pr_url, log_context, data)
await _perform_commands_bitbucket(
"pr_commands", PRAgent(), pr_url, log_context, data
)
elif event == "pullrequest:comment_created":
pr_url = data["data"]["pullrequest"]["links"]["html"]["href"]
log_context["api_url"] = pr_url
log_context["event"] = "comment"
comment_body = data["data"]["comment"]["content"]["raw"]
with get_logger().contextualize(**log_context):
if get_identity_provider().verify_eligibility("bitbucket",
sender_id, pr_url) is not Eligibility.NOT_ELIGIBLE:
if (
get_identity_provider().verify_eligibility(
"bitbucket", sender_id, pr_url
)
is not Eligibility.NOT_ELIGIBLE
):
await agent.handle_request(pr_url, comment_body)
except Exception as e:
get_logger().error(f"Failed to handle webhook: {e}")
background_tasks.add_task(inner)
return "OK"
@router.get("/webhook")
async def handle_github_webhooks(request: Request, response: Response):
return "Webhook server online!"
@router.post("/installed")
async def handle_installed_webhooks(request: Request, response: Response):
try:
@ -240,15 +282,13 @@ async def handle_installed_webhooks(request: Request, response: Response):
shared_secret = data["sharedSecret"]
client_key = data["clientKey"]
username = data["principal"]["username"]
secrets = {
"shared_secret": shared_secret,
"client_key": client_key
}
secrets = {"shared_secret": shared_secret, "client_key": client_key}
secret_provider.store_secret(username, json.dumps(secrets))
except Exception as e:
get_logger().error(f"Failed to register user: {e}")
return JSONResponse({"error": "Unable to register user"}, status_code=500)
@router.post("/uninstalled")
async def handle_uninstalled_webhooks(request: Request, response: Response):
get_logger().info("handle_uninstalled_webhooks")

View File

@ -40,10 +40,12 @@ def handle_request(
background_tasks.add_task(inner)
@router.post("/")
async def redirect_to_webhook():
return RedirectResponse(url="/webhook")
@router.post("/webhook")
async def handle_webhook(background_tasks: BackgroundTasks, request: Request):
log_context = {"server_type": "bitbucket_server"}
@ -55,7 +57,8 @@ async def handle_webhook(background_tasks: BackgroundTasks, request: Request):
body_bytes = await request.body()
if body_bytes.decode('utf-8') == '{"test": true}':
return JSONResponse(
status_code=status.HTTP_200_OK, content=jsonable_encoder({"message": "connection test successful"})
status_code=status.HTTP_200_OK,
content=jsonable_encoder({"message": "connection test successful"}),
)
signature_header = request.headers.get("x-hub-signature", None)
verify_signature(body_bytes, webhook_secret, signature_header)
@ -73,11 +76,18 @@ async def handle_webhook(background_tasks: BackgroundTasks, request: Request):
if data["eventKey"] == "pr:opened":
apply_repo_settings(pr_url)
if get_settings().config.disable_auto_feedback: # auto commands for PR, and auto feedback is disabled
get_logger().info(f"Auto feedback is disabled, skipping auto commands for PR {pr_url}", **log_context)
if (
get_settings().config.disable_auto_feedback
): # auto commands for PR, and auto feedback is disabled
get_logger().info(
f"Auto feedback is disabled, skipping auto commands for PR {pr_url}",
**log_context,
)
return
get_settings().set("config.is_auto_command", True)
commands_to_run.extend(_get_commands_list_from_settings('BITBUCKET_SERVER.PR_COMMANDS'))
commands_to_run.extend(
_get_commands_list_from_settings('BITBUCKET_SERVER.PR_COMMANDS')
)
elif data["eventKey"] == "pr:comment:added":
commands_to_run.append(data["comment"]["text"])
else:
@ -116,6 +126,7 @@ async def _run_commands_sequentially(commands: List[str], url: str, log_context:
except Exception as e:
get_logger().error(f"Failed to handle command: {command} , error: {e}")
def _process_command(command: str, url) -> str:
# don't think we need this
apply_repo_settings(url)
@ -142,11 +153,13 @@ def _to_list(command_string: str) -> list:
raise ValueError(f"Invalid command string: {e}")
def _get_commands_list_from_settings(setting_key:str ) -> list:
def _get_commands_list_from_settings(setting_key: str) -> list:
try:
return get_settings().get(setting_key, [])
except ValueError as e:
get_logger().error(f"Failed to get commands list from settings {setting_key}: {e}")
get_logger().error(
f"Failed to get commands list from settings {setting_key}: {e}"
)
@router.get("/")

View File

@ -40,12 +40,10 @@ async def handle_gerrit_request(action: Action, item: Item):
if action == Action.ask:
if not item.msg:
return HTTPException(
status_code=400,
detail="msg is required for ask command"
status_code=400, detail="msg is required for ask command"
)
await PRAgent().handle_request(
f"{item.project}:{item.refspec}",
f"/{item.msg.strip()}"
f"{item.project}:{item.refspec}", f"/{item.msg.strip()}"
)

View File

@ -26,7 +26,12 @@ def get_setting_or_env(key: str, default: Union[str, bool] = None) -> Union[str,
try:
value = get_settings().get(key, default)
except AttributeError: # TBD still need to debug why this happens on GitHub Actions
value = os.getenv(key, None) or os.getenv(key.upper(), None) or os.getenv(key.lower(), None) or default
value = (
os.getenv(key, None)
or os.getenv(key.upper(), None)
or os.getenv(key.lower(), None)
or default
)
return value
@ -76,16 +81,24 @@ async def run_action():
pr_url = event_payload.get("pull_request", {}).get("html_url")
if pr_url:
apply_repo_settings(pr_url)
get_logger().info(f"enable_custom_labels: {get_settings().config.enable_custom_labels}")
get_logger().info(
f"enable_custom_labels: {get_settings().config.enable_custom_labels}"
)
except Exception as e:
get_logger().info(f"github action: failed to apply repo settings: {e}")
# Handle pull request opened event
if GITHUB_EVENT_NAME == "pull_request" or GITHUB_EVENT_NAME == "pull_request_target":
if (
GITHUB_EVENT_NAME == "pull_request"
or GITHUB_EVENT_NAME == "pull_request_target"
):
action = event_payload.get("action")
# Retrieve the list of actions from the configuration
pr_actions = get_settings().get("GITHUB_ACTION_CONFIG.PR_ACTIONS", ["opened", "reopened", "ready_for_review", "review_requested"])
pr_actions = get_settings().get(
"GITHUB_ACTION_CONFIG.PR_ACTIONS",
["opened", "reopened", "ready_for_review", "review_requested"],
)
if action in pr_actions:
pr_url = event_payload.get("pull_request", {}).get("url")
@ -93,18 +106,30 @@ async def run_action():
# legacy - supporting both GITHUB_ACTION and GITHUB_ACTION_CONFIG
auto_review = get_setting_or_env("GITHUB_ACTION.AUTO_REVIEW", None)
if auto_review is None:
auto_review = get_setting_or_env("GITHUB_ACTION_CONFIG.AUTO_REVIEW", None)
auto_review = get_setting_or_env(
"GITHUB_ACTION_CONFIG.AUTO_REVIEW", None
)
auto_describe = get_setting_or_env("GITHUB_ACTION.AUTO_DESCRIBE", None)
if auto_describe is None:
auto_describe = get_setting_or_env("GITHUB_ACTION_CONFIG.AUTO_DESCRIBE", None)
auto_describe = get_setting_or_env(
"GITHUB_ACTION_CONFIG.AUTO_DESCRIBE", None
)
auto_improve = get_setting_or_env("GITHUB_ACTION.AUTO_IMPROVE", None)
if auto_improve is None:
auto_improve = get_setting_or_env("GITHUB_ACTION_CONFIG.AUTO_IMPROVE", None)
auto_improve = get_setting_or_env(
"GITHUB_ACTION_CONFIG.AUTO_IMPROVE", None
)
# Set the configuration for auto actions
get_settings().config.is_auto_command = True # Set the flag to indicate that the command is auto
get_settings().pr_description.final_update_message = False # No final update message when auto_describe is enabled
get_logger().info(f"Running auto actions: auto_describe={auto_describe}, auto_review={auto_review}, auto_improve={auto_improve}")
get_settings().config.is_auto_command = (
True # Set the flag to indicate that the command is auto
)
get_settings().pr_description.final_update_message = (
False # No final update message when auto_describe is enabled
)
get_logger().info(
f"Running auto actions: auto_describe={auto_describe}, auto_review={auto_review}, auto_improve={auto_improve}"
)
# invoke by default all three tools
if auto_describe is None or is_true(auto_describe):
@ -117,7 +142,10 @@ async def run_action():
get_logger().info(f"Skipping action: {action}")
# Handle issue comment event
elif GITHUB_EVENT_NAME == "issue_comment" or GITHUB_EVENT_NAME == "pull_request_review_comment":
elif (
GITHUB_EVENT_NAME == "issue_comment"
or GITHUB_EVENT_NAME == "pull_request_review_comment"
):
action = event_payload.get("action")
if action in ["created", "edited"]:
comment_body = event_payload.get("comment", {}).get("body")
@ -133,9 +161,15 @@ async def run_action():
disable_eyes = False
# check if issue is pull request
if event_payload.get("issue", {}).get("pull_request"):
url = event_payload.get("issue", {}).get("pull_request", {}).get("url")
url = (
event_payload.get("issue", {})
.get("pull_request", {})
.get("url")
)
is_pr = True
elif event_payload.get("comment", {}).get("pull_request_url"): # for 'pull_request_review_comment
elif event_payload.get("comment", {}).get(
"pull_request_url"
): # for 'pull_request_review_comment
url = event_payload.get("comment", {}).get("pull_request_url")
is_pr = True
disable_eyes = True
@ -148,9 +182,11 @@ async def run_action():
provider = get_git_provider()(pr_url=url)
if is_pr:
await PRAgent().handle_request(
url, body, notify=lambda: provider.add_eyes_reaction(
url,
body,
notify=lambda: provider.add_eyes_reaction(
comment_id, disable_eyes=disable_eyes
)
),
)
else:
await PRAgent().handle_request(url, body)
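Aside: a minimal sketch of the settings-then-environment fallback this action entrypoint relies on; the settings accessor is passed in here to keep the sketch self-contained, whereas the original imports it from the config loader:

import os
from typing import Union

def get_setting_or_env(get_settings, key: str, default: Union[str, bool] = None) -> Union[str, bool]:
    try:
        value = get_settings().get(key, default)
    except AttributeError:  # observed on GitHub Actions when settings are not fully loaded
        value = (
            os.getenv(key, None)
            or os.getenv(key.upper(), None)
            or os.getenv(key.lower(), None)
            or default
        )
    return value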

View File

@ -15,8 +15,7 @@ from starlette_context.middleware import RawContextMiddleware
from utils.pr_agent.agent.pr_agent import PRAgent
from utils.pr_agent.algo.utils import update_settings_from_args
from utils.pr_agent.config_loader import get_settings, global_settings
from utils.pr_agent.git_providers import (get_git_provider,
get_git_provider_with_context)
from utils.pr_agent.git_providers import get_git_provider, get_git_provider_with_context
from utils.pr_agent.git_providers.utils import apply_repo_settings
from utils.pr_agent.identity_providers import get_identity_provider
from utils.pr_agent.identity_providers.identity_provider import Eligibility
@ -35,7 +34,9 @@ router = APIRouter()
@router.post("/api/v1/github_webhooks")
async def handle_github_webhooks(background_tasks: BackgroundTasks, request: Request, response: Response):
async def handle_github_webhooks(
background_tasks: BackgroundTasks, request: Request, response: Response
):
"""
Receives and processes incoming GitHub webhook requests.
Verifies the request signature, parses the request body, and passes it to the handle_request function for further
@ -49,7 +50,9 @@ async def handle_github_webhooks(background_tasks: BackgroundTasks, request: Req
context["installation_id"] = installation_id
context["settings"] = copy.deepcopy(global_settings)
context["git_provider"] = {}
background_tasks.add_task(handle_request, body, event=request.headers.get("X-GitHub-Event", None))
background_tasks.add_task(
handle_request, body, event=request.headers.get("X-GitHub-Event", None)
)
return {}
@ -73,35 +76,61 @@ async def get_body(request):
return body
_duplicate_push_triggers = DefaultDictWithTimeout(ttl=get_settings().github_app.push_trigger_pending_tasks_ttl)
_pending_task_duplicate_push_conditions = DefaultDictWithTimeout(asyncio.locks.Condition, ttl=get_settings().github_app.push_trigger_pending_tasks_ttl)
_duplicate_push_triggers = DefaultDictWithTimeout(
ttl=get_settings().github_app.push_trigger_pending_tasks_ttl
)
_pending_task_duplicate_push_conditions = DefaultDictWithTimeout(
asyncio.locks.Condition,
ttl=get_settings().github_app.push_trigger_pending_tasks_ttl,
)
async def handle_comments_on_pr(body: Dict[str, Any],
event: str,
sender: str,
sender_id: str,
action: str,
log_context: Dict[str, Any],
agent: PRAgent):
async def handle_comments_on_pr(
body: Dict[str, Any],
event: str,
sender: str,
sender_id: str,
action: str,
log_context: Dict[str, Any],
agent: PRAgent,
):
if "comment" not in body:
return {}
comment_body = body.get("comment", {}).get("body")
if comment_body and isinstance(comment_body, str) and not comment_body.lstrip().startswith("/"):
if (
comment_body
and isinstance(comment_body, str)
and not comment_body.lstrip().startswith("/")
):
if '/ask' in comment_body and comment_body.strip().startswith('> ![image]'):
comment_body_split = comment_body.split('/ask')
comment_body = '/ask' + comment_body_split[1] +' \n' +comment_body_split[0].strip().lstrip('>')
get_logger().info(f"Reformatting comment_body so command is at the beginning: {comment_body}")
comment_body = (
'/ask'
+ comment_body_split[1]
+ ' \n'
+ comment_body_split[0].strip().lstrip('>')
)
get_logger().info(
f"Reformatting comment_body so command is at the beginning: {comment_body}"
)
else:
get_logger().info("Ignoring comment not starting with /")
return {}
disable_eyes = False
if "issue" in body and "pull_request" in body["issue"] and "url" in body["issue"]["pull_request"]:
if (
"issue" in body
and "pull_request" in body["issue"]
and "url" in body["issue"]["pull_request"]
):
api_url = body["issue"]["pull_request"]["url"]
elif "comment" in body and "pull_request_url" in body["comment"]:
api_url = body["comment"]["pull_request_url"]
try:
if ('/ask' in comment_body and
'subject_type' in body["comment"] and body["comment"]["subject_type"] == "line"):
if (
'/ask' in comment_body
and 'subject_type' in body["comment"]
and body["comment"]["subject_type"] == "line"
):
# comment on a code line in the "files changed" tab
comment_body = handle_line_comments(body, comment_body)
disable_eyes = True
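As a worked example of the /ask reordering above (the input text is illustrative): a comment that starts with a quoted image and puts the command afterwards is rewritten so the command comes first.

comment_body = "> ![image](https://example.com/screenshot.png)\n/ask what does this error mean?"

if '/ask' in comment_body and comment_body.strip().startswith('> ![image]'):
    comment_body_split = comment_body.split('/ask')
    comment_body = '/ask' + comment_body_split[1] + ' \n' + comment_body_split[0].strip().lstrip('>')

print(comment_body)
# -> "/ask what does this error mean? \n ![image](https://example.com/screenshot.png)"
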
@ -113,46 +142,75 @@ async def handle_comments_on_pr(body: Dict[str, Any],
comment_id = body.get("comment", {}).get("id")
provider = get_git_provider_with_context(pr_url=api_url)
with get_logger().contextualize(**log_context):
if get_identity_provider().verify_eligibility("github", sender_id, api_url) is not Eligibility.NOT_ELIGIBLE:
get_logger().info(f"Processing comment on PR {api_url=}, comment_body={comment_body}")
await agent.handle_request(api_url, comment_body,
notify=lambda: provider.add_eyes_reaction(comment_id, disable_eyes=disable_eyes))
if (
get_identity_provider().verify_eligibility("github", sender_id, api_url)
is not Eligibility.NOT_ELIGIBLE
):
get_logger().info(
f"Processing comment on PR {api_url=}, comment_body={comment_body}"
)
await agent.handle_request(
api_url,
comment_body,
notify=lambda: provider.add_eyes_reaction(
comment_id, disable_eyes=disable_eyes
),
)
else:
get_logger().info(f"User {sender=} is not eligible to process comment on PR {api_url=}")
get_logger().info(
f"User {sender=} is not eligible to process comment on PR {api_url=}"
)
async def handle_new_pr_opened(body: Dict[str, Any],
event: str,
sender: str,
sender_id: str,
action: str,
log_context: Dict[str, Any],
agent: PRAgent):
async def handle_new_pr_opened(
body: Dict[str, Any],
event: str,
sender: str,
sender_id: str,
action: str,
log_context: Dict[str, Any],
agent: PRAgent,
):
title = body.get("pull_request", {}).get("title", "")
pull_request, api_url = _check_pull_request_event(action, body, log_context)
if not (pull_request and api_url):
get_logger().info(f"Invalid PR event: {action=} {api_url=}")
return {}
if action in get_settings().github_app.handle_pr_actions: # ['opened', 'reopened', 'ready_for_review']
if (
action in get_settings().github_app.handle_pr_actions
): # ['opened', 'reopened', 'ready_for_review']
# logic to ignore PRs with specific titles (e.g. "[Auto] ...")
apply_repo_settings(api_url)
if get_identity_provider().verify_eligibility("github", sender_id, api_url) is not Eligibility.NOT_ELIGIBLE:
await _perform_auto_commands_github("pr_commands", agent, body, api_url, log_context)
if (
get_identity_provider().verify_eligibility("github", sender_id, api_url)
is not Eligibility.NOT_ELIGIBLE
):
await _perform_auto_commands_github(
"pr_commands", agent, body, api_url, log_context
)
else:
get_logger().info(f"User {sender=} is not eligible to process PR {api_url=}")
get_logger().info(
f"User {sender=} is not eligible to process PR {api_url=}"
)
async def handle_push_trigger_for_new_commits(body: Dict[str, Any],
event: str,
sender: str,
sender_id: str,
action: str,
log_context: Dict[str, Any],
agent: PRAgent):
async def handle_push_trigger_for_new_commits(
body: Dict[str, Any],
event: str,
sender: str,
sender_id: str,
action: str,
log_context: Dict[str, Any],
agent: PRAgent,
):
pull_request, api_url = _check_pull_request_event(action, body, log_context)
if not (pull_request and api_url):
return {}
apply_repo_settings(api_url) # we need to apply the repo settings to get the correct settings for the PR. This is quite expensive - a call to the git provider is made for each PR event.
apply_repo_settings(
api_url
) # we need to apply the repo settings to get the correct settings for the PR. This is quite expensive - a call to the git provider is made for each PR event.
if not get_settings().github_app.handle_push_trigger:
return {}
@ -162,7 +220,10 @@ async def handle_push_trigger_for_new_commits(body: Dict[str, Any],
merge_commit_sha = pull_request.get("merge_commit_sha")
if before_sha == after_sha:
return {}
if get_settings().github_app.push_trigger_ignore_merge_commits and after_sha == merge_commit_sha:
if (
get_settings().github_app.push_trigger_ignore_merge_commits
and after_sha == merge_commit_sha
):
return {}
# Prevent triggering multiple times for subsequent push triggers when one is enough:
@ -172,7 +233,9 @@ async def handle_push_trigger_for_new_commits(body: Dict[str, Any],
# more commits may have been pushed that led to the subsequent events,
# so we keep just one waiting as a delegate to trigger the processing for the new commits when done waiting.
current_active_tasks = _duplicate_push_triggers.setdefault(api_url, 0)
max_active_tasks = 2 if get_settings().github_app.push_trigger_pending_tasks_backlog else 1
max_active_tasks = (
2 if get_settings().github_app.push_trigger_pending_tasks_backlog else 1
)
if current_active_tasks < max_active_tasks:
# first task can enter, and second tasks too if backlog is enabled
get_logger().info(
@ -191,12 +254,21 @@ async def handle_push_trigger_for_new_commits(body: Dict[str, Any],
f"Waiting to process push trigger for {api_url=} because the first task is still in progress"
)
await _pending_task_duplicate_push_conditions[api_url].wait()
get_logger().info(f"Finished waiting to process push trigger for {api_url=} - continue with flow")
get_logger().info(
f"Finished waiting to process push trigger for {api_url=} - continue with flow"
)
try:
if get_identity_provider().verify_eligibility("github", sender_id, api_url) is not Eligibility.NOT_ELIGIBLE:
get_logger().info(f"Performing incremental review for {api_url=} because of {event=} and {action=}")
await _perform_auto_commands_github("push_commands", agent, body, api_url, log_context)
if (
get_identity_provider().verify_eligibility("github", sender_id, api_url)
is not Eligibility.NOT_ELIGIBLE
):
get_logger().info(
f"Performing incremental review for {api_url=} because of {event=} and {action=}"
)
await _perform_auto_commands_github(
"push_commands", agent, body, api_url, log_context
)
finally:
# release the waiting task block
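The counter/condition pair above implements "at most one running task per PR URL, plus one queued delegate that reruns once the first finishes". A compact, self-contained sketch of that idea; the helper names and exact structure are illustrative, not the module's own implementation:

import asyncio
from collections import defaultdict

_active = defaultdict(int)                    # api_url -> number of active or queued tasks
_conditions = defaultdict(asyncio.Condition)  # api_url -> condition the queued delegate waits on

async def handle_push(api_url: str, process) -> None:
    max_active = 2  # one running task plus one queued delegate (backlog enabled)
    if _active[api_url] >= max_active:
        return  # a delegate is already queued for this PR; drop the extra trigger
    _active[api_url] += 1
    cond = _conditions[api_url]
    async with cond:
        if _active[api_url] > 1:   # another task is running: wait to become the delegate
            await cond.wait()
    try:
        await process(api_url)
    finally:
        async with cond:
            _active[api_url] -= 1
            cond.notify(1)         # wake the queued delegate, if any
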
@ -213,7 +285,12 @@ def handle_closed_pr(body, event, action, log_context):
api_url = pull_request.get("url", "")
pr_statistics = get_git_provider()(pr_url=api_url).calc_pr_statistics(pull_request)
log_context["api_url"] = api_url
get_logger().info("PR-Agent statistics for closed PR", analytics=True, pr_statistics=pr_statistics, **log_context)
get_logger().info(
"PR-Agent statistics for closed PR",
analytics=True,
pr_statistics=pr_statistics,
**log_context,
)
def get_log_context(body, event, action, build_number):
@ -228,9 +305,18 @@ def get_log_context(body, event, action, build_number):
git_org = body.get("organization", {}).get("login", "")
installation_id = body.get("installation", {}).get("id", "")
app_name = get_settings().get("CONFIG.APP_NAME", "Unknown")
log_context = {"action": action, "event": event, "sender": sender, "server_type": "github_app",
"request_id": uuid.uuid4().hex, "build_number": build_number, "app_name": app_name,
"repo": repo, "git_org": git_org, "installation_id": installation_id}
log_context = {
"action": action,
"event": event,
"sender": sender,
"server_type": "github_app",
"request_id": uuid.uuid4().hex,
"build_number": build_number,
"app_name": app_name,
"repo": repo,
"git_org": git_org,
"installation_id": installation_id,
}
except Exception as e:
get_logger().error("Failed to get log context", e)
log_context = {}
@ -240,7 +326,10 @@ def get_log_context(body, event, action, build_number):
def is_bot_user(sender, sender_type):
try:
# logic to ignore PRs opened by bot
if get_settings().get("GITHUB_APP.IGNORE_BOT_PR", False) and sender_type == "Bot":
if (
get_settings().get("GITHUB_APP.IGNORE_BOT_PR", False)
and sender_type == "Bot"
):
if 'pr-agent' not in sender:
get_logger().info(f"Ignoring PR from '{sender=}' because it is a bot")
return True
@ -262,7 +351,9 @@ def should_process_pr_logic(body) -> bool:
ignore_pr_users = get_settings().get("CONFIG.IGNORE_PR_AUTHORS", [])
if ignore_pr_users and sender:
if sender in ignore_pr_users:
get_logger().info(f"Ignoring PR from user '{sender}' due to 'config.ignore_pr_authors' setting")
get_logger().info(
f"Ignoring PR from user '{sender}' due to 'config.ignore_pr_authors' setting"
)
return False
# logic to ignore PRs with specific titles
@ -270,8 +361,12 @@ def should_process_pr_logic(body) -> bool:
ignore_pr_title_re = get_settings().get("CONFIG.IGNORE_PR_TITLE", [])
if not isinstance(ignore_pr_title_re, list):
ignore_pr_title_re = [ignore_pr_title_re]
if ignore_pr_title_re and any(re.search(regex, title) for regex in ignore_pr_title_re):
get_logger().info(f"Ignoring PR with title '{title}' due to config.ignore_pr_title setting")
if ignore_pr_title_re and any(
re.search(regex, title) for regex in ignore_pr_title_re
):
get_logger().info(
f"Ignoring PR with title '{title}' due to config.ignore_pr_title setting"
)
return False
# logic to ignore PRs with specific labels or source branches or target branches.
@ -280,20 +375,32 @@ def should_process_pr_logic(body) -> bool:
labels = [label['name'] for label in pr_labels]
if any(label in ignore_pr_labels for label in labels):
labels_str = ", ".join(labels)
get_logger().info(f"Ignoring PR with labels '{labels_str}' due to config.ignore_pr_labels settings")
get_logger().info(
f"Ignoring PR with labels '{labels_str}' due to config.ignore_pr_labels settings"
)
return False
# logic to ignore PRs with specific source or target branches
ignore_pr_source_branches = get_settings().get("CONFIG.IGNORE_PR_SOURCE_BRANCHES", [])
ignore_pr_target_branches = get_settings().get("CONFIG.IGNORE_PR_TARGET_BRANCHES", [])
ignore_pr_source_branches = get_settings().get(
"CONFIG.IGNORE_PR_SOURCE_BRANCHES", []
)
ignore_pr_target_branches = get_settings().get(
"CONFIG.IGNORE_PR_TARGET_BRANCHES", []
)
if pull_request and (ignore_pr_source_branches or ignore_pr_target_branches):
if any(re.search(regex, source_branch) for regex in ignore_pr_source_branches):
if any(
re.search(regex, source_branch) for regex in ignore_pr_source_branches
):
get_logger().info(
f"Ignoring PR with source branch '{source_branch}' due to config.ignore_pr_source_branches settings")
f"Ignoring PR with source branch '{source_branch}' due to config.ignore_pr_source_branches settings"
)
return False
if any(re.search(regex, target_branch) for regex in ignore_pr_target_branches):
if any(
re.search(regex, target_branch) for regex in ignore_pr_target_branches
):
get_logger().info(
f"Ignoring PR with target branch '{target_branch}' due to config.ignore_pr_target_branches settings")
f"Ignoring PR with target branch '{target_branch}' due to config.ignore_pr_target_branches settings"
)
return False
except Exception as e:
get_logger().error(f"Failed 'should_process_pr_logic': {e}")
@ -308,11 +415,15 @@ async def handle_request(body: Dict[str, Any], event: str):
body: The request body.
event: The GitHub event type (e.g. "pull_request", "issue_comment", etc.).
"""
action = body.get("action") # "created", "opened", "reopened", "ready_for_review", "review_requested", "synchronize"
action = body.get(
"action"
) # "created", "opened", "reopened", "ready_for_review", "review_requested", "synchronize"
if not action:
return {}
agent = PRAgent()
log_context, sender, sender_id, sender_type = get_log_context(body, event, action, build_number)
log_context, sender, sender_id, sender_type = get_log_context(
body, event, action, build_number
)
# logic to ignore PRs opened by bot, PRs with specific titles, labels, source branches, or target branches
if is_bot_user(sender, sender_type) and 'check_run' not in body:
@ -327,21 +438,29 @@ async def handle_request(body: Dict[str, Any], event: str):
# handle comments on PRs
elif action == 'created':
get_logger().debug(f'Request body', artifact=body, event=event)
await handle_comments_on_pr(body, event, sender, sender_id, action, log_context, agent)
await handle_comments_on_pr(
body, event, sender, sender_id, action, log_context, agent
)
# handle new PRs
elif event == 'pull_request' and action != 'synchronize' and action != 'closed':
get_logger().debug(f'Request body', artifact=body, event=event)
await handle_new_pr_opened(body, event, sender, sender_id, action, log_context, agent)
await handle_new_pr_opened(
body, event, sender, sender_id, action, log_context, agent
)
elif event == "issue_comment" and 'edited' in action:
pass # handle_checkbox_clicked
pass # handle_checkbox_clicked
# handle pull_request event with synchronize action - "push trigger" for new commits
elif event == 'pull_request' and action == 'synchronize':
await handle_push_trigger_for_new_commits(body, event, sender,sender_id, action, log_context, agent)
await handle_push_trigger_for_new_commits(
body, event, sender, sender_id, action, log_context, agent
)
elif event == 'pull_request' and action == 'closed':
if get_settings().get("CONFIG.ANALYTICS_FOLDER", ""):
handle_closed_pr(body, event, action, log_context)
else:
get_logger().info(f"event {event=} action {action=} does not require any handling")
get_logger().info(
f"event {event=} action {action=} does not require any handling"
)
return {}
@ -362,7 +481,9 @@ def handle_line_comments(body: Dict, comment_body: [str, Any]) -> str:
return comment_body
def _check_pull_request_event(action: str, body: dict, log_context: dict) -> Tuple[Dict[str, Any], str]:
def _check_pull_request_event(
action: str, body: dict, log_context: dict
) -> Tuple[Dict[str, Any], str]:
invalid_result = {}, ""
pull_request = body.get("pull_request")
if not pull_request:
@ -373,19 +494,28 @@ def _check_pull_request_event(action: str, body: dict, log_context: dict) -> Tup
log_context["api_url"] = api_url
if pull_request.get("draft", True) or pull_request.get("state") != "open":
return invalid_result
if action in ("review_requested", "synchronize") and pull_request.get("created_at") == pull_request.get("updated_at"):
if action in ("review_requested", "synchronize") and pull_request.get(
"created_at"
) == pull_request.get("updated_at"):
# avoid double reviews when opening a PR for the first time
return invalid_result
return pull_request, api_url
async def _perform_auto_commands_github(commands_conf: str, agent: PRAgent, body: dict, api_url: str,
log_context: dict):
async def _perform_auto_commands_github(
commands_conf: str, agent: PRAgent, body: dict, api_url: str, log_context: dict
):
apply_repo_settings(api_url)
if commands_conf == "pr_commands" and get_settings().config.disable_auto_feedback: # auto commands for PR, and auto feedback is disabled
get_logger().info(f"Auto feedback is disabled, skipping auto commands for PR {api_url=}")
if (
commands_conf == "pr_commands" and get_settings().config.disable_auto_feedback
): # auto commands for PR, and auto feedback is disabled
get_logger().info(
f"Auto feedback is disabled, skipping auto commands for PR {api_url=}"
)
return
if not should_process_pr_logic(body): # Here we already updated the configuration with the repo settings
if not should_process_pr_logic(
body
): # Here we already updated the configuration with the repo settings
return {}
commands = get_settings().get(f"github_app.{commands_conf}")
if not commands:
@ -398,7 +528,9 @@ async def _perform_auto_commands_github(commands_conf: str, agent: PRAgent, body
args = split_command[1:]
other_args = update_settings_from_args(args)
new_command = ' '.join([command] + other_args)
get_logger().info(f"{commands_conf}. Performing auto command '{new_command}', for {api_url=}")
get_logger().info(
f"{commands_conf}. Performing auto command '{new_command}', for {api_url=}"
)
await agent.handle_request(api_url, new_command)
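A minimal sketch of how a configured auto command is split into the base command plus its --section.key=value overrides before being replayed; the command string is illustrative and reuses the template shown later in this diff:

command_str = "/review --pr_reviewer.some_config1=true --config.verbosity_level=2"

split_command = command_str.split(" ")
command = split_command[0]   # "/review"
args = split_command[1:]     # ["--pr_reviewer.some_config1=true", "--config.verbosity_level=2"]
# In the real flow, update_settings_from_args(args) applies these overrides to the
# settings object and returns the recognised arguments; here they are re-joined as-is.
new_command = ' '.join([command] + args)
print(new_command)  # /review --pr_reviewer.some_config1=true --config.verbosity_level=2
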


@ -18,11 +18,13 @@ NOTIFICATION_URL = "https://api.github.com/notifications"
async def mark_notification_as_read(headers, notification, session):
async with session.patch(
f"https://api.github.com/notifications/threads/{notification['id']}",
headers=headers) as mark_read_response:
f"https://api.github.com/notifications/threads/{notification['id']}",
headers=headers,
) as mark_read_response:
if mark_read_response.status != 205:
get_logger().error(
f"Failed to mark notification as read. Status code: {mark_read_response.status}")
f"Failed to mark notification as read. Status code: {mark_read_response.status}"
)
def now() -> str:
@ -36,17 +38,21 @@ def now() -> str:
now_utc = now_utc.replace("+00:00", "Z")
return now_utc
async def async_handle_request(pr_url, rest_of_comment, comment_id, git_provider):
agent = PRAgent()
success = await agent.handle_request(
pr_url,
rest_of_comment,
notify=lambda: git_provider.add_eyes_reaction(comment_id)
notify=lambda: git_provider.add_eyes_reaction(comment_id),
)
return success
def run_handle_request(pr_url, rest_of_comment, comment_id, git_provider):
return asyncio.run(async_handle_request(pr_url, rest_of_comment, comment_id, git_provider))
return asyncio.run(
async_handle_request(pr_url, rest_of_comment, comment_id, git_provider)
)
def process_comment_sync(pr_url, rest_of_comment, comment_id):
@ -55,7 +61,10 @@ def process_comment_sync(pr_url, rest_of_comment, comment_id):
git_provider = get_git_provider()(pr_url=pr_url)
success = run_handle_request(pr_url, rest_of_comment, comment_id, git_provider)
except Exception as e:
get_logger().error(f"Error processing comment: {e}", artifact={"traceback": traceback.format_exc()})
get_logger().error(
f"Error processing comment: {e}",
artifact={"traceback": traceback.format_exc()},
)
async def process_comment(pr_url, rest_of_comment, comment_id):
@ -66,22 +75,31 @@ async def process_comment(pr_url, rest_of_comment, comment_id):
success = await agent.handle_request(
pr_url,
rest_of_comment,
notify=lambda: git_provider.add_eyes_reaction(comment_id)
notify=lambda: git_provider.add_eyes_reaction(comment_id),
)
get_logger().info(f"Finished processing comment for PR: {pr_url}")
except Exception as e:
get_logger().error(f"Error processing comment: {e}", artifact={"traceback": traceback.format_exc()})
get_logger().error(
f"Error processing comment: {e}",
artifact={"traceback": traceback.format_exc()},
)
async def is_valid_notification(notification, headers, handled_ids, session, user_id):
try:
if 'reason' in notification and notification['reason'] == 'mention':
if 'subject' in notification and notification['subject']['type'] == 'PullRequest':
if (
'subject' in notification
and notification['subject']['type'] == 'PullRequest'
):
pr_url = notification['subject']['url']
latest_comment = notification['subject']['latest_comment_url']
if not latest_comment or not isinstance(latest_comment, str):
get_logger().debug(f"no latest_comment")
return False, handled_ids
async with session.get(latest_comment, headers=headers) as comment_response:
async with session.get(
latest_comment, headers=headers
) as comment_response:
check_prev_comments = False
user_tag = "@" + user_id
if comment_response.status == 200:
@ -94,7 +112,9 @@ async def is_valid_notification(notification, headers, handled_ids, session, use
handled_ids.add(comment['id'])
if 'user' in comment and 'login' in comment['user']:
if comment['user']['login'] == user_id:
get_logger().debug(f"comment['user']['login'] == user_id")
get_logger().debug(
f"comment['user']['login'] == user_id"
)
check_prev_comments = True
comment_body = comment.get('body', '')
if not comment_body:
@ -105,15 +125,28 @@ async def is_valid_notification(notification, headers, handled_ids, session, use
get_logger().debug(f"user_tag not in comment_body")
check_prev_comments = True
else:
get_logger().info(f"Polling, pr_url: {pr_url}",
artifact={"comment": comment_body})
get_logger().info(
f"Polling, pr_url: {pr_url}",
artifact={"comment": comment_body},
)
if not check_prev_comments:
return True, handled_ids, comment, comment_body, pr_url, user_tag
else: # we could not find the user tag in the latest comment. Check previous comments
return (
True,
handled_ids,
comment,
comment_body,
pr_url,
user_tag,
)
else: # we could not find the user tag in the latest comment. Check previous comments
# get all comments in the PR
requests_url = f"{pr_url}/comments".replace("pulls", "issues")
comments_response = requests.get(requests_url, headers=headers)
requests_url = f"{pr_url}/comments".replace(
"pulls", "issues"
)
comments_response = requests.get(
requests_url, headers=headers
)
comments = comments_response.json()[::-1]
max_comment_to_scan = 4
for comment in comments[:max_comment_to_scan]:
@ -124,23 +157,37 @@ async def is_valid_notification(notification, headers, handled_ids, session, use
if not comment_body:
continue
if user_tag in comment_body:
get_logger().info("found user tag in previous comments")
get_logger().info(f"Polling, pr_url: {pr_url}",
artifact={"comment": comment_body})
return True, handled_ids, comment, comment_body, pr_url, user_tag
get_logger().info(
"found user tag in previous comments"
)
get_logger().info(
f"Polling, pr_url: {pr_url}",
artifact={"comment": comment_body},
)
return (
True,
handled_ids,
comment,
comment_body,
pr_url,
user_tag,
)
get_logger().warning(f"Failed to fetch comments for PR: {pr_url}",
artifact={"comments": comments})
get_logger().warning(
f"Failed to fetch comments for PR: {pr_url}",
artifact={"comments": comments},
)
return False, handled_ids
return False, handled_ids
except Exception as e:
get_logger().exception(f"Error processing polling notification",
artifact={"notification": notification, "error": e})
get_logger().exception(
f"Error processing polling notification",
artifact={"notification": notification, "error": e},
)
return False, handled_ids
async def polling_loop():
"""
Polls for notifications and handles them accordingly.
@ -171,17 +218,17 @@ async def polling_loop():
await asyncio.sleep(5)
headers = {
"Accept": "application/vnd.github.v3+json",
"Authorization": f"Bearer {token}"
}
params = {
"participating": "true"
"Authorization": f"Bearer {token}",
}
params = {"participating": "true"}
if since[0]:
params["since"] = since[0]
if last_modified[0]:
headers["If-Modified-Since"] = last_modified[0]
async with session.get(NOTIFICATION_URL, headers=headers, params=params) as response:
async with session.get(
NOTIFICATION_URL, headers=headers, params=params
) as response:
if response.status == 200:
if 'Last-Modified' in response.headers:
last_modified[0] = response.headers['Last-Modified']
@ -189,39 +236,67 @@ async def polling_loop():
notifications = await response.json()
if not notifications:
continue
get_logger().info(f"Received {len(notifications)} notifications")
get_logger().info(
f"Received {len(notifications)} notifications"
)
task_queue = deque()
for notification in notifications:
if not notification:
continue
# mark notification as read
await mark_notification_as_read(headers, notification, session)
await mark_notification_as_read(
headers, notification, session
)
handled_ids.add(notification['id'])
output = await is_valid_notification(notification, headers, handled_ids, session, user_id)
output = await is_valid_notification(
notification, headers, handled_ids, session, user_id
)
if output[0]:
_, handled_ids, comment, comment_body, pr_url, user_tag = output
rest_of_comment = comment_body.split(user_tag)[1].strip()
(
_,
handled_ids,
comment,
comment_body,
pr_url,
user_tag,
) = output
rest_of_comment = comment_body.split(user_tag)[
1
].strip()
comment_id = comment['id']
# Add to the task queue
get_logger().info(
f"Adding comment processing to task queue for PR, {pr_url}, comment_body: {comment_body}")
task_queue.append((process_comment_sync, (pr_url, rest_of_comment, comment_id)))
get_logger().info(f"Queued comment processing for PR: {pr_url}")
f"Adding comment processing to task queue for PR, {pr_url}, comment_body: {comment_body}"
)
task_queue.append(
(
process_comment_sync,
(pr_url, rest_of_comment, comment_id),
)
)
get_logger().info(
f"Queued comment processing for PR: {pr_url}"
)
else:
get_logger().debug(f"Skipping comment processing for PR")
get_logger().debug(
f"Skipping comment processing for PR"
)
max_allowed_parallel_tasks = 10
if task_queue:
processes = []
for i, (func, args) in enumerate(task_queue): # Create parallel tasks
for i, (func, args) in enumerate(
task_queue
): # Create parallel tasks
p = multiprocessing.Process(target=func, args=args)
processes.append(p)
p.start()
if i > max_allowed_parallel_tasks:
get_logger().error(
f"Dropping {len(task_queue) - max_allowed_parallel_tasks} tasks from polling session")
f"Dropping {len(task_queue) - max_allowed_parallel_tasks} tasks from polling session"
)
break
task_queue.clear()
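The drain step above launches one OS process per queued comment and stops once the per-cycle cap is hit. A self-contained sketch of that bounded fan-out, with a stand-in worker instead of process_comment_sync (here the cap is checked before launching, a slightly stricter variant):

import multiprocessing

def worker(pr_url: str, comment_id: int) -> None:
    print(f"processing {pr_url} / comment {comment_id}")  # stand-in for the real handler

if __name__ == "__main__":
    task_queue = [("https://api.github.com/repos/org/repo/pulls/1", i) for i in range(15)]
    max_allowed_parallel_tasks = 10
    processes = []
    for i, args in enumerate(task_queue):
        if i >= max_allowed_parallel_tasks:
            print(f"Dropping {len(task_queue) - max_allowed_parallel_tasks} tasks")
            break
        p = multiprocessing.Process(target=worker, args=args)
        processes.append(p)
        p.start()
    for p in processes:
        p.join()
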
@ -230,11 +305,15 @@ async def polling_loop():
# p.join()
elif response.status != 304:
print(f"Failed to fetch notifications. Status code: {response.status}")
print(
f"Failed to fetch notifications. Status code: {response.status}"
)
except Exception as e:
get_logger().error(f"Polling exception during processing of a notification: {e}",
artifact={"traceback": traceback.format_exc()})
get_logger().error(
f"Polling exception during processing of a notification: {e}",
artifact={"traceback": traceback.format_exc()},
)
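The surrounding loop polls the notifications endpoint with conditional requests: it replays the Last-Modified value as If-Modified-Since so an unchanged inbox answers 304 with no body. A minimal sketch of just that mechanism, assuming only aiohttp and a placeholder token:

import asyncio
import aiohttp

NOTIFICATION_URL = "https://api.github.com/notifications"

async def poll_once(session: aiohttp.ClientSession, token: str, last_modified: list) -> list:
    headers = {
        "Accept": "application/vnd.github.v3+json",
        "Authorization": f"Bearer {token}",
    }
    if last_modified[0]:
        headers["If-Modified-Since"] = last_modified[0]
    async with session.get(NOTIFICATION_URL, headers=headers, params={"participating": "true"}) as response:
        if response.status == 200:
            last_modified[0] = response.headers.get("Last-Modified", last_modified[0])
            return await response.json()
        return []  # 304 Not Modified (or an error) -> nothing to process this cycle

async def main():
    last_modified = [None]
    async with aiohttp.ClientSession() as session:
        while True:
            notifications = await poll_once(session, "<GITHUB_TOKEN>", last_modified)
            print(f"received {len(notifications)} notifications")
            await asyncio.sleep(5)

# asyncio.run(main())  # requires a valid token; left commented out on purpose
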
if __name__ == '__main__':


@ -22,20 +22,21 @@ from utils.pr_agent.secret_providers import get_secret_provider
setup_logger(fmt=LoggingFormat.JSON, level="DEBUG")
router = APIRouter()
secret_provider = get_secret_provider() if get_settings().get("CONFIG.SECRET_PROVIDER") else None
secret_provider = (
get_secret_provider() if get_settings().get("CONFIG.SECRET_PROVIDER") else None
)
async def get_mr_url_from_commit_sha(commit_sha, gitlab_token, project_id):
try:
import requests
headers = {
'Private-Token': f'{gitlab_token}'
}
headers = {'Private-Token': f'{gitlab_token}'}
# API endpoint to find MRs containing the commit
gitlab_url = get_settings().get("GITLAB.URL", 'https://gitlab.com')
response = requests.get(
f'{gitlab_url}/api/v4/projects/{project_id}/repository/commits/{commit_sha}/merge_requests',
headers=headers
headers=headers,
)
merge_requests = response.json()
if merge_requests and response.status_code == 200:
@ -48,6 +49,7 @@ async def get_mr_url_from_commit_sha(commit_sha, gitlab_token, project_id):
get_logger().error(f"Failed to get MR url from commit sha: {e}")
return None
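A stripped-down synchronous sketch of the lookup above, using the documented GitLab endpoint; the project id, commit SHA and token are placeholders, and returning the first MR's web_url is an assumption about the part of the function not shown in this hunk:

import requests

def mr_url_for_commit(gitlab_url: str, project_id: int, commit_sha: str, token: str) -> str | None:
    response = requests.get(
        f"{gitlab_url}/api/v4/projects/{project_id}/repository/commits/{commit_sha}/merge_requests",
        headers={"Private-Token": token},
        timeout=30,
    )
    merge_requests = response.json()
    if response.status_code == 200 and merge_requests:
        return merge_requests[0].get("web_url")   # assumption: the first MR is the relevant one
    return None

# mr_url_for_commit("https://gitlab.com", 123, "deadbeef", "<PERSONAL_ACCESS_TOKEN>")
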
async def handle_request(api_url: str, body: str, log_context: dict, sender_id: str):
log_context["action"] = body
log_context["event"] = "pull_request" if body == "/review" else "comment"
@ -58,13 +60,19 @@ async def handle_request(api_url: str, body: str, log_context: dict, sender_id:
await PRAgent().handle_request(api_url, body)
async def _perform_commands_gitlab(commands_conf: str, agent: PRAgent, api_url: str,
log_context: dict, data: dict):
async def _perform_commands_gitlab(
commands_conf: str, agent: PRAgent, api_url: str, log_context: dict, data: dict
):
apply_repo_settings(api_url)
if commands_conf == "pr_commands" and get_settings().config.disable_auto_feedback: # auto commands for PR, and auto feedback is disabled
get_logger().info(f"Auto feedback is disabled, skipping auto commands for PR {api_url=}", **log_context)
if (
commands_conf == "pr_commands" and get_settings().config.disable_auto_feedback
): # auto commands for PR, and auto feedback is disabled
get_logger().info(
f"Auto feedback is disabled, skipping auto commands for PR {api_url=}",
**log_context,
)
return
if not should_process_pr_logic(data): # Here we already updated the configurations
if not should_process_pr_logic(data): # Here we already updated the configurations
return
commands = get_settings().get(f"gitlab.{commands_conf}", {})
get_settings().set("config.is_auto_command", True)
@ -106,40 +114,58 @@ def should_process_pr_logic(data) -> bool:
ignore_pr_users = get_settings().get("CONFIG.IGNORE_PR_AUTHORS", [])
if ignore_pr_users and sender:
if sender in ignore_pr_users:
get_logger().info(f"Ignoring PR from user '{sender}' due to 'config.ignore_pr_authors' settings")
get_logger().info(
f"Ignoring PR from user '{sender}' due to 'config.ignore_pr_authors' settings"
)
return False
# logic to ignore MRs for titles, labels and source, target branches.
ignore_mr_title = get_settings().get("CONFIG.IGNORE_PR_TITLE", [])
ignore_mr_labels = get_settings().get("CONFIG.IGNORE_PR_LABELS", [])
ignore_mr_source_branches = get_settings().get("CONFIG.IGNORE_PR_SOURCE_BRANCHES", [])
ignore_mr_target_branches = get_settings().get("CONFIG.IGNORE_PR_TARGET_BRANCHES", [])
ignore_mr_source_branches = get_settings().get(
"CONFIG.IGNORE_PR_SOURCE_BRANCHES", []
)
ignore_mr_target_branches = get_settings().get(
"CONFIG.IGNORE_PR_TARGET_BRANCHES", []
)
#
if ignore_mr_source_branches:
source_branch = data['object_attributes'].get('source_branch')
if any(re.search(regex, source_branch) for regex in ignore_mr_source_branches):
if any(
re.search(regex, source_branch) for regex in ignore_mr_source_branches
):
get_logger().info(
f"Ignoring MR with source branch '{source_branch}' due to gitlab.ignore_mr_source_branches settings")
f"Ignoring MR with source branch '{source_branch}' due to gitlab.ignore_mr_source_branches settings"
)
return False
if ignore_mr_target_branches:
target_branch = data['object_attributes'].get('target_branch')
if any(re.search(regex, target_branch) for regex in ignore_mr_target_branches):
if any(
re.search(regex, target_branch) for regex in ignore_mr_target_branches
):
get_logger().info(
f"Ignoring MR with target branch '{target_branch}' due to gitlab.ignore_mr_target_branches settings")
f"Ignoring MR with target branch '{target_branch}' due to gitlab.ignore_mr_target_branches settings"
)
return False
if ignore_mr_labels:
labels = [label['title'] for label in data['object_attributes'].get('labels', [])]
labels = [
label['title'] for label in data['object_attributes'].get('labels', [])
]
if any(label in ignore_mr_labels for label in labels):
labels_str = ", ".join(labels)
get_logger().info(f"Ignoring MR with labels '{labels_str}' due to gitlab.ignore_mr_labels settings")
get_logger().info(
f"Ignoring MR with labels '{labels_str}' due to gitlab.ignore_mr_labels settings"
)
return False
if ignore_mr_title:
if any(re.search(regex, title) for regex in ignore_mr_title):
get_logger().info(f"Ignoring MR with title '{title}' due to gitlab.ignore_mr_title settings")
get_logger().info(
f"Ignoring MR with title '{title}' due to gitlab.ignore_mr_title settings"
)
return False
except Exception as e:
get_logger().error(f"Failed 'should_process_pr_logic': {e}")
@ -159,29 +185,47 @@ async def gitlab_webhook(background_tasks: BackgroundTasks, request: Request):
request_token = request.headers.get("X-Gitlab-Token")
secret = secret_provider.get_secret(request_token)
if not secret:
get_logger().warning(f"Empty secret retrieved, request_token: {request_token}")
return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED,
content=jsonable_encoder({"message": "unauthorized"}))
get_logger().warning(
f"Empty secret retrieved, request_token: {request_token}"
)
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content=jsonable_encoder({"message": "unauthorized"}),
)
try:
secret_dict = json.loads(secret)
gitlab_token = secret_dict["gitlab_token"]
log_context["token_id"] = secret_dict.get("token_name", secret_dict.get("id", "unknown"))
log_context["token_id"] = secret_dict.get(
"token_name", secret_dict.get("id", "unknown")
)
context["settings"].gitlab.personal_access_token = gitlab_token
except Exception as e:
get_logger().error(f"Failed to validate secret {request_token}: {e}")
return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content=jsonable_encoder({"message": "unauthorized"}))
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content=jsonable_encoder({"message": "unauthorized"}),
)
elif get_settings().get("GITLAB.SHARED_SECRET"):
secret = get_settings().get("GITLAB.SHARED_SECRET")
if not request.headers.get("X-Gitlab-Token") == secret:
get_logger().error("Failed to validate secret")
return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content=jsonable_encoder({"message": "unauthorized"}))
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content=jsonable_encoder({"message": "unauthorized"}),
)
else:
get_logger().error("Failed to validate secret")
return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content=jsonable_encoder({"message": "unauthorized"}))
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content=jsonable_encoder({"message": "unauthorized"}),
)
gitlab_token = get_settings().get("GITLAB.PERSONAL_ACCESS_TOKEN", None)
if not gitlab_token:
get_logger().error("No gitlab token found")
return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content=jsonable_encoder({"message": "unauthorized"}))
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content=jsonable_encoder({"message": "unauthorized"}),
)
get_logger().info("GitLab data", artifact=data)
sender = data.get("user", {}).get("username", "unknown")
@ -189,31 +233,49 @@ async def gitlab_webhook(background_tasks: BackgroundTasks, request: Request):
# ignore bot users
if is_bot_user(data):
return JSONResponse(status_code=status.HTTP_200_OK, content=jsonable_encoder({"message": "success"}))
if data.get('event_type') != 'note': # not a comment
return JSONResponse(
status_code=status.HTTP_200_OK,
content=jsonable_encoder({"message": "success"}),
)
if data.get('event_type') != 'note': # not a comment
# ignore MRs based on title, labels, source and target branches
if not should_process_pr_logic(data):
return JSONResponse(status_code=status.HTTP_200_OK, content=jsonable_encoder({"message": "success"}))
return JSONResponse(
status_code=status.HTTP_200_OK,
content=jsonable_encoder({"message": "success"}),
)
log_context["sender"] = sender
if data.get('object_kind') == 'merge_request' and data['object_attributes'].get('action') in ['open', 'reopen']:
if data.get('object_kind') == 'merge_request' and data['object_attributes'].get(
'action'
) in ['open', 'reopen']:
title = data['object_attributes'].get('title')
url = data['object_attributes'].get('url')
draft = data['object_attributes'].get('draft')
get_logger().info(f"New merge request: {url}")
if draft:
get_logger().info(f"Skipping draft MR: {url}")
return JSONResponse(status_code=status.HTTP_200_OK, content=jsonable_encoder({"message": "success"}))
return JSONResponse(
status_code=status.HTTP_200_OK,
content=jsonable_encoder({"message": "success"}),
)
await _perform_commands_gitlab("pr_commands", PRAgent(), url, log_context, data)
elif data.get('object_kind') == 'note' and data.get('event_type') == 'note': # comment on MR
await _perform_commands_gitlab(
"pr_commands", PRAgent(), url, log_context, data
)
elif (
data.get('object_kind') == 'note' and data.get('event_type') == 'note'
): # comment on MR
if 'merge_request' in data:
mr = data['merge_request']
url = mr.get('url')
get_logger().info(f"A comment has been added to a merge request: {url}")
body = data.get('object_attributes', {}).get('note')
if data.get('object_attributes', {}).get('type') == 'DiffNote' and '/ask' in body: # /ask_line
if (
data.get('object_attributes', {}).get('type') == 'DiffNote'
and '/ask' in body
): # /ask_line
body = handle_ask_line(body, data)
await handle_request(url, body, log_context, sender_id)
@ -221,30 +283,44 @@ async def gitlab_webhook(background_tasks: BackgroundTasks, request: Request):
try:
project_id = data['project_id']
commit_sha = data['checkout_sha']
url = await get_mr_url_from_commit_sha(commit_sha, gitlab_token, project_id)
url = await get_mr_url_from_commit_sha(
commit_sha, gitlab_token, project_id
)
if not url:
get_logger().info(f"No MR found for commit: {commit_sha}")
return JSONResponse(status_code=status.HTTP_200_OK,
content=jsonable_encoder({"message": "success"}))
return JSONResponse(
status_code=status.HTTP_200_OK,
content=jsonable_encoder({"message": "success"}),
)
# we need first to apply_repo_settings
apply_repo_settings(url)
commands_on_push = get_settings().get(f"gitlab.push_commands", {})
handle_push_trigger = get_settings().get(f"gitlab.handle_push_trigger", False)
handle_push_trigger = get_settings().get(
f"gitlab.handle_push_trigger", False
)
if not commands_on_push or not handle_push_trigger:
get_logger().info("Push event, but no push commands found or push trigger is disabled")
return JSONResponse(status_code=status.HTTP_200_OK,
content=jsonable_encoder({"message": "success"}))
get_logger().info(
"Push event, but no push commands found or push trigger is disabled"
)
return JSONResponse(
status_code=status.HTTP_200_OK,
content=jsonable_encoder({"message": "success"}),
)
get_logger().debug(f'A push event has been received: {url}')
await _perform_commands_gitlab("push_commands", PRAgent(), url, log_context, data)
await _perform_commands_gitlab(
"push_commands", PRAgent(), url, log_context, data
)
except Exception as e:
get_logger().error(f"Failed to handle push event: {e}")
background_tasks.add_task(inner, request_json)
end_time = datetime.now()
get_logger().info(f"Processing time: {end_time - start_time}", request=request_json)
return JSONResponse(status_code=status.HTTP_200_OK, content=jsonable_encoder({"message": "success"}))
return JSONResponse(
status_code=status.HTTP_200_OK, content=jsonable_encoder({"message": "success"})
)
def handle_ask_line(body, data):
@ -271,6 +347,7 @@ def handle_ask_line(body, data):
async def root():
return {"status": "ok"}
gitlab_url = get_settings().get("GITLAB.URL", None)
if not gitlab_url:
raise ValueError("GITLAB.URL is not set")


@ -1,18 +1,19 @@
class HelpMessage:
@staticmethod
def get_general_commands_text():
commands_text = "> - **/review**: Request a review of your Pull Request. \n" \
"> - **/describe**: Update the PR title and description based on the contents of the PR. \n" \
"> - **/improve [--extended]**: Suggest code improvements. Extended mode provides a higher quality feedback. \n" \
"> - **/ask \\<QUESTION\\>**: Ask a question about the PR. \n" \
"> - **/update_changelog**: Update the changelog based on the PR's contents. \n" \
"> - **/add_docs** 💎: Generate docstring for new components introduced in the PR. \n" \
"> - **/generate_labels** 💎: Generate labels for the PR based on the PR's contents. \n" \
"> - **/analyze** 💎: Automatically analyzes the PR, and presents changes walkthrough for each component. \n\n" \
">See the [tools guide](https://pr-agent-docs.codium.ai/tools/) for more details.\n" \
">To list the possible configuration parameters, add a **/config** comment. \n"
return commands_text
commands_text = (
"> - **/review**: Request a review of your Pull Request. \n"
"> - **/describe**: Update the PR title and description based on the contents of the PR. \n"
"> - **/improve [--extended]**: Suggest code improvements. Extended mode provides a higher quality feedback. \n"
"> - **/ask \\<QUESTION\\>**: Ask a question about the PR. \n"
"> - **/update_changelog**: Update the changelog based on the PR's contents. \n"
"> - **/add_docs** 💎: Generate docstring for new components introduced in the PR. \n"
"> - **/generate_labels** 💎: Generate labels for the PR based on the PR's contents. \n"
"> - **/analyze** 💎: Automatically analyzes the PR, and presents changes walkthrough for each component. \n\n"
">See the [tools guide](https://pr-agent-docs.codium.ai/tools/) for more details.\n"
">To list the possible configuration parameters, add a **/config** comment. \n"
)
return commands_text
@staticmethod
def get_general_bot_help_text():
@ -21,10 +22,12 @@ class HelpMessage:
@staticmethod
def get_review_usage_guide():
output ="**Overview:**\n"
output +=("The `review` tool scans the PR code changes, and generates a PR review which includes several types of feedbacks, such as possible PR issues, security threats and relevant test in the PR. More feedbacks can be [added](https://pr-agent-docs.codium.ai/tools/review/#general-configurations) by configuring the tool.\n\n"
"The tool can be triggered [automatically](https://pr-agent-docs.codium.ai/usage-guide/automations_and_usage/#github-app-automatic-tools-when-a-new-pr-is-opened) every time a new PR is opened, or can be invoked manually by commenting on any PR.\n")
output +="""\
output = "**Overview:**\n"
output += (
"The `review` tool scans the PR code changes, and generates a PR review which includes several types of feedbacks, such as possible PR issues, security threats and relevant test in the PR. More feedbacks can be [added](https://pr-agent-docs.codium.ai/tools/review/#general-configurations) by configuring the tool.\n\n"
"The tool can be triggered [automatically](https://pr-agent-docs.codium.ai/usage-guide/automations_and_usage/#github-app-automatic-tools-when-a-new-pr-is-opened) every time a new PR is opened, or can be invoked manually by commenting on any PR.\n"
)
output += """\
- When commenting, to edit [configurations](https://github.com/Codium-ai/pr-agent/blob/main/pr_agent/settings/configuration.toml#L23) related to the review tool (`pr_reviewer` section), use the following template:
```
/review --pr_reviewer.some_config1=... --pr_reviewer.some_config2=...
@ -41,8 +44,6 @@ some_config2=...
return output
@staticmethod
def get_describe_usage_guide():
output = "**Overview:**\n"
@ -137,7 +138,6 @@ Use triple quotes to write multi-line instructions. Use bullet points to make th
'''
output += "\n\n</details></td></tr>\n\n"
# general
output += "\n\n<tr><td><details> <summary><strong> More PR-Agent commands</strong></summary><hr> \n\n"
output += HelpMessage.get_general_bot_help_text()
@ -175,7 +175,6 @@ You can ask questions about the entire PR, about specific code lines, or about a
return output
@staticmethod
def get_improve_usage_guide():
output = "**Overview:**\n"


@ -18,8 +18,12 @@ def verify_signature(payload_body, secret_token, signature_header):
signature_header: header received from GitHub (x-hub-signature-256)
"""
if not signature_header:
raise HTTPException(status_code=403, detail="x-hub-signature-256 header is missing!")
hash_object = hmac.new(secret_token.encode('utf-8'), msg=payload_body, digestmod=hashlib.sha256)
raise HTTPException(
status_code=403, detail="x-hub-signature-256 header is missing!"
)
hash_object = hmac.new(
secret_token.encode('utf-8'), msg=payload_body, digestmod=hashlib.sha256
)
expected_signature = "sha256=" + hash_object.hexdigest()
if not hmac.compare_digest(expected_signature, signature_header):
raise HTTPException(status_code=403, detail="Request signatures didn't match!")
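A small standalone example of the HMAC check above (the secret and payload are made up): it shows what a matching x-hub-signature-256 value looks like and how the constant-time comparison is used.

import hashlib
import hmac

secret_token = "my-webhook-secret"        # illustrative
payload_body = b'{"action": "opened"}'    # raw request body bytes

expected_signature = "sha256=" + hmac.new(
    secret_token.encode('utf-8'), msg=payload_body, digestmod=hashlib.sha256
).hexdigest()

incoming_header = expected_signature      # pretend GitHub signed the same body with the same secret
print(hmac.compare_digest(expected_signature, incoming_header))  # True -> request accepted
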
@ -27,6 +31,7 @@ def verify_signature(payload_body, secret_token, signature_header):
class RateLimitExceeded(Exception):
"""Raised when the git provider API rate limit has been exceeded."""
pass
@ -66,7 +71,11 @@ class DefaultDictWithTimeout(defaultdict):
request_time = self.__time()
if request_time - self.__last_refresh > self.__refresh_interval:
return
to_delete = [key for key, key_time in self.__key_times.items() if request_time - key_time > self.__ttl]
to_delete = [
key
for key, key_time in self.__key_times.items()
if request_time - key_time > self.__ttl
]
for key in to_delete:
del self[key]
self.__last_refresh = request_time
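DefaultDictWithTimeout behaves like a defaultdict whose entries expire: stale keys are swept when the mapping is next used, at most once per refresh interval. A toy, self-contained version of the same sweep (without the refresh-interval throttling or pluggable clock of the real class):

import time
from collections import defaultdict

class TTLDefaultDict(defaultdict):
    """Toy version of the purge above: drop keys not touched for `ttl` seconds."""

    def __init__(self, default_factory=None, *, ttl: float):
        super().__init__(default_factory)
        self._ttl = ttl
        self._key_times = {}

    def __getitem__(self, key):
        self._purge()
        self._key_times[key] = time.monotonic()
        return super().__getitem__(key)

    def __setitem__(self, key, value):
        self._key_times[key] = time.monotonic()
        super().__setitem__(key, value)

    def _purge(self):
        now = time.monotonic()
        to_delete = [k for k, t in self._key_times.items() if now - t > self._ttl]
        for k in to_delete:
            super().__delitem__(k)
            del self._key_times[k]

counters = TTLDefaultDict(int, ttl=300)
counters["https://api.github.com/repos/org/repo/pulls/1"] += 1  # stale keys are swept on access
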


@ -17,9 +17,13 @@ from utils.pr_agent.log import get_logger
class PRAddDocs:
def __init__(self, pr_url: str, cli_mode=False, args: list = None,
ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
def __init__(
self,
pr_url: str,
cli_mode=False,
args: list = None,
ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler,
):
self.git_provider = get_git_provider()(pr_url)
self.main_language = get_main_pr_language(
self.git_provider.get_languages(), self.git_provider.get_files()
@ -39,13 +43,16 @@ class PRAddDocs:
"diff": "", # empty diff for initial calculation
"extra_instructions": get_settings().pr_add_docs.extra_instructions,
"commit_messages_str": self.git_provider.get_commit_messages(),
'docs_for_language': get_docs_for_language(self.main_language,
get_settings().pr_add_docs.docs_style),
'docs_for_language': get_docs_for_language(
self.main_language, get_settings().pr_add_docs.docs_style
),
}
self.token_handler = TokenHandler(self.git_provider.pr,
self.vars,
get_settings().pr_add_docs_prompt.system,
get_settings().pr_add_docs_prompt.user)
self.token_handler = TokenHandler(
self.git_provider.pr,
self.vars,
get_settings().pr_add_docs_prompt.system,
get_settings().pr_add_docs_prompt.user,
)
async def run(self):
try:
@ -66,16 +73,20 @@ class PRAddDocs:
get_logger().info('Pushing inline code documentation...')
self.push_inline_docs(data)
except Exception as e:
get_logger().error(f"Failed to generate code documentation for PR, error: {e}")
get_logger().error(
f"Failed to generate code documentation for PR, error: {e}"
)
async def _prepare_prediction(self, model: str):
get_logger().info('Getting PR diff...')
self.patches_diff = get_pr_diff(self.git_provider,
self.token_handler,
model,
add_line_numbers_to_hunks=True,
disable_extra_lines=False)
self.patches_diff = get_pr_diff(
self.git_provider,
self.token_handler,
model,
add_line_numbers_to_hunks=True,
disable_extra_lines=False,
)
get_logger().info('Getting AI prediction...')
self.prediction = await self._get_prediction(model)
@ -84,13 +95,21 @@ class PRAddDocs:
variables = copy.deepcopy(self.vars)
variables["diff"] = self.patches_diff # update diff
environment = Environment(undefined=StrictUndefined)
system_prompt = environment.from_string(get_settings().pr_add_docs_prompt.system).render(variables)
user_prompt = environment.from_string(get_settings().pr_add_docs_prompt.user).render(variables)
system_prompt = environment.from_string(
get_settings().pr_add_docs_prompt.system
).render(variables)
user_prompt = environment.from_string(
get_settings().pr_add_docs_prompt.user
).render(variables)
if get_settings().config.verbosity_level >= 2:
get_logger().info(f"\nSystem prompt:\n{system_prompt}")
get_logger().info(f"\nUser prompt:\n{user_prompt}")
response, finish_reason = await self.ai_handler.chat_completion(
model=model, temperature=get_settings().config.temperature, system=system_prompt, user=user_prompt)
model=model,
temperature=get_settings().config.temperature,
system=system_prompt,
user=user_prompt,
)
return response
@ -105,7 +124,9 @@ class PRAddDocs:
docs = []
if not data['Code Documentation']:
return self.git_provider.publish_comment('No code documentation found to improve this PR.')
return self.git_provider.publish_comment(
'No code documentation found to improve this PR.'
)
for d in data['Code Documentation']:
try:
@ -116,32 +137,59 @@ class PRAddDocs:
documentation = d['documentation']
doc_placement = d['doc placement'].strip()
if documentation:
new_code_snippet = self.dedent_code(relevant_file, relevant_line, documentation, doc_placement,
add_original_line=True)
new_code_snippet = self.dedent_code(
relevant_file,
relevant_line,
documentation,
doc_placement,
add_original_line=True,
)
body = f"**Suggestion:** Proposed documentation\n```suggestion\n" + new_code_snippet + "\n```"
docs.append({'body': body, 'relevant_file': relevant_file,
'relevant_lines_start': relevant_line,
'relevant_lines_end': relevant_line})
body = (
f"**Suggestion:** Proposed documentation\n```suggestion\n"
+ new_code_snippet
+ "\n```"
)
docs.append(
{
'body': body,
'relevant_file': relevant_file,
'relevant_lines_start': relevant_line,
'relevant_lines_end': relevant_line,
}
)
except Exception:
if get_settings().config.verbosity_level >= 2:
get_logger().info(f"Could not parse code docs: {d}")
is_successful = self.git_provider.publish_code_suggestions(docs)
if not is_successful:
get_logger().info("Failed to publish code docs, trying to publish each docs separately")
get_logger().info(
"Failed to publish code docs, trying to publish each docs separately"
)
for doc_suggestion in docs:
self.git_provider.publish_code_suggestions([doc_suggestion])
def dedent_code(self, relevant_file, relevant_lines_start, new_code_snippet, doc_placement='after',
add_original_line=False):
def dedent_code(
self,
relevant_file,
relevant_lines_start,
new_code_snippet,
doc_placement='after',
add_original_line=False,
):
try: # dedent code snippet
self.diff_files = self.git_provider.diff_files if self.git_provider.diff_files \
self.diff_files = (
self.git_provider.diff_files
if self.git_provider.diff_files
else self.git_provider.get_diff_files()
)
original_initial_line = None
for file in self.diff_files:
if file.filename.strip() == relevant_file:
original_initial_line = file.head_file.splitlines()[relevant_lines_start - 1]
original_initial_line = file.head_file.splitlines()[
relevant_lines_start - 1
]
break
if original_initial_line:
if doc_placement == 'after':
@ -150,18 +198,28 @@ class PRAddDocs:
line = original_initial_line
suggested_initial_line = new_code_snippet.splitlines()[0]
original_initial_spaces = len(line) - len(line.lstrip())
suggested_initial_spaces = len(suggested_initial_line) - len(suggested_initial_line.lstrip())
suggested_initial_spaces = len(suggested_initial_line) - len(
suggested_initial_line.lstrip()
)
delta_spaces = original_initial_spaces - suggested_initial_spaces
if delta_spaces > 0:
new_code_snippet = textwrap.indent(new_code_snippet, delta_spaces * " ").rstrip('\n')
new_code_snippet = textwrap.indent(
new_code_snippet, delta_spaces * " "
).rstrip('\n')
if add_original_line:
if doc_placement == 'after':
new_code_snippet = original_initial_line + "\n" + new_code_snippet
new_code_snippet = (
original_initial_line + "\n" + new_code_snippet
)
else:
new_code_snippet = new_code_snippet.rstrip() + "\n" + original_initial_line
new_code_snippet = (
new_code_snippet.rstrip() + "\n" + original_initial_line
)
except Exception as e:
if get_settings().config.verbosity_level >= 2:
get_logger().info(f"Could not dedent code snippet for file {relevant_file}, error: {e}")
get_logger().info(
f"Could not dedent code snippet for file {relevant_file}, error: {e}"
)
return new_code_snippet
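A worked example of the re-indentation above (snippet contents are illustrative): the original line is indented by 8 spaces and the suggested docstring by 4, so the suggestion is shifted right by the 4-space difference and then attached after the original line.

import textwrap

original_initial_line = '        def load(self):'           # 8 leading spaces
new_code_snippet = '    """Load the cached settings."""'    # 4 leading spaces

original_spaces = len(original_initial_line) - len(original_initial_line.lstrip())
suggested_line = new_code_snippet.splitlines()[0]
suggested_spaces = len(suggested_line) - len(suggested_line.lstrip())
delta_spaces = original_spaces - suggested_spaces            # 8 - 4 = 4
if delta_spaces > 0:
    new_code_snippet = textwrap.indent(new_code_snippet, delta_spaces * " ").rstrip('\n')

# doc placement 'after' keeps the original line first:
print(original_initial_line + "\n" + new_code_snippet)
#         def load(self):
#         """Load the cached settings."""
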

File diff suppressed because it is too large


@ -9,6 +9,7 @@ class PRConfig:
"""
The PRConfig class is responsible for listing all configuration options available for the user.
"""
def __init__(self, pr_url: str, args=None, ai_handler=None):
"""
Initialize the PRConfig object with the necessary attributes and objects to comment on a pull request.
@ -34,20 +35,43 @@ class PRConfig:
conf_settings = Dynaconf(settings_files=[conf_file])
configuration_headers = [header.lower() for header in conf_settings.keys()]
relevant_configs = {
header: configs for header, configs in get_settings().to_dict().items()
if (header.lower().startswith("pr_") or header.lower().startswith("config")) and header.lower() in configuration_headers
header: configs
for header, configs in get_settings().to_dict().items()
if (header.lower().startswith("pr_") or header.lower().startswith("config"))
and header.lower() in configuration_headers
}
skip_keys = ['ai_disclaimer', 'ai_disclaimer_title', 'ANALYTICS_FOLDER', 'secret_provider', "skip_keys", "app_id", "redirect",
'trial_prefix_message', 'no_eligible_message', 'identity_provider', 'ALLOWED_REPOS',
'APP_NAME', 'PERSONAL_ACCESS_TOKEN', 'shared_secret', 'key', 'AWS_ACCESS_KEY_ID', 'AWS_SECRET_ACCESS_KEY', 'user_token',
'private_key', 'private_key_id', 'client_id', 'client_secret', 'token', 'bearer_token']
skip_keys = [
'ai_disclaimer',
'ai_disclaimer_title',
'ANALYTICS_FOLDER',
'secret_provider',
"skip_keys",
"app_id",
"redirect",
'trial_prefix_message',
'no_eligible_message',
'identity_provider',
'ALLOWED_REPOS',
'APP_NAME',
'PERSONAL_ACCESS_TOKEN',
'shared_secret',
'key',
'AWS_ACCESS_KEY_ID',
'AWS_SECRET_ACCESS_KEY',
'user_token',
'private_key',
'private_key_id',
'client_id',
'client_secret',
'token',
'bearer_token',
]
extra_skip_keys = get_settings().config.get('config.skip_keys', [])
if extra_skip_keys:
skip_keys.extend(extra_skip_keys)
skip_keys_lower = [key.lower() for key in skip_keys]
markdown_text = "<details> <summary><strong>🛠️ PR-Agent Configurations:</strong></summary> \n\n"
markdown_text += f"\n\n```yaml\n\n"
for header, configs in relevant_configs.items():
@ -61,5 +85,7 @@ class PRConfig:
markdown_text += " "
markdown_text += "\n```"
markdown_text += "\n</details>\n"
get_logger().info(f"Possible Configurations outputted to PR comment", artifact=markdown_text)
get_logger().info(
f"Possible Configurations outputted to PR comment", artifact=markdown_text
)
return markdown_text
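A condensed sketch of the filtering above with a made-up settings snapshot: keep only pr_*/config sections that also exist in the packaged configuration file, and drop any key on the secret/skip list. (In the real method the skip list is applied while rendering the YAML block; it is folded into the comprehension here for brevity.)

all_settings = {
    "config": {"verbosity_level": 1, "secret_provider": "aws"},
    "pr_description": {"publish_labels": True},
    "github": {"user_token": "xxxx"},                     # not a pr_*/config section -> dropped
}
configuration_headers = ["config", "pr_description"]      # headers found in configuration.toml
skip_keys_lower = ["secret_provider", "user_token"]

relevant_configs = {
    header: {k: v for k, v in configs.items() if k.lower() not in skip_keys_lower}
    for header, configs in all_settings.items()
    if (header.lower().startswith("pr_") or header.lower().startswith("config"))
    and header.lower() in configuration_headers
}
print(relevant_configs)
# {'config': {'verbosity_level': 1}, 'pr_description': {'publish_labels': True}}
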


@ -10,27 +10,38 @@ from jinja2 import Environment, StrictUndefined
from utils.pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from utils.pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler
from utils.pr_agent.algo.pr_processing import (OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD,
get_pr_diff,
get_pr_diff_multiple_patchs,
retry_with_fallback_models)
from utils.pr_agent.algo.pr_processing import (
OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD,
get_pr_diff,
get_pr_diff_multiple_patchs,
retry_with_fallback_models,
)
from utils.pr_agent.algo.token_handler import TokenHandler
from utils.pr_agent.algo.utils import (ModelType, PRDescriptionHeader, clip_tokens,
get_max_tokens, get_user_labels, load_yaml,
set_custom_labels,
show_relevant_configurations)
from utils.pr_agent.algo.utils import (
ModelType,
PRDescriptionHeader,
clip_tokens,
get_max_tokens,
get_user_labels,
load_yaml,
set_custom_labels,
show_relevant_configurations,
)
from utils.pr_agent.config_loader import get_settings
from utils.pr_agent.git_providers import (GithubProvider, get_git_provider_with_context)
from utils.pr_agent.git_providers import GithubProvider, get_git_provider_with_context
from utils.pr_agent.git_providers.git_provider import get_main_pr_language
from utils.pr_agent.log import get_logger
from utils.pr_agent.servers.help import HelpMessage
from utils.pr_agent.tools.ticket_pr_compliance_check import (
extract_and_cache_pr_tickets)
from utils.pr_agent.tools.ticket_pr_compliance_check import extract_and_cache_pr_tickets
class PRDescription:
def __init__(self, pr_url: str, args: list = None,
ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
def __init__(
self,
pr_url: str,
args: list = None,
ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler,
):
"""
Initialize the PRDescription object with the necessary attributes and objects for generating a PR description
using an AI model.
@ -44,11 +55,22 @@ class PRDescription:
self.git_provider.get_languages(), self.git_provider.get_files()
)
self.pr_id = self.git_provider.get_pr_id()
self.keys_fix = ["filename:", "language:", "changes_summary:", "changes_title:", "description:", "title:"]
self.keys_fix = [
"filename:",
"language:",
"changes_summary:",
"changes_title:",
"description:",
"title:",
]
if get_settings().pr_description.enable_semantic_files_types and not self.git_provider.is_supported(
"gfm_markdown"):
get_logger().debug(f"Disabling semantic files types for {self.pr_id}, gfm_markdown not supported.")
if (
get_settings().pr_description.enable_semantic_files_types
and not self.git_provider.is_supported("gfm_markdown")
):
get_logger().debug(
f"Disabling semantic files types for {self.pr_id}, gfm_markdown not supported."
)
get_settings().pr_description.enable_semantic_files_types = False
# Initialize the AI handler
@ -56,7 +78,9 @@ class PRDescription:
self.ai_handler.main_pr_language = self.main_pr_language
# Initialize the variables dictionary
self.COLLAPSIBLE_FILE_LIST_THRESHOLD = get_settings().pr_description.get("collapsible_file_list_threshold", 8)
self.COLLAPSIBLE_FILE_LIST_THRESHOLD = get_settings().pr_description.get(
"collapsible_file_list_threshold", 8
)
self.vars = {
"title": self.git_provider.pr.title,
"branch": self.git_provider.get_pr_branch(),
@ -69,8 +93,11 @@ class PRDescription:
"custom_labels_class": "", # will be filled if necessary in 'set_custom_labels' function
"enable_semantic_files_types": get_settings().pr_description.enable_semantic_files_types,
"related_tickets": "",
"include_file_summary_changes": len(self.git_provider.get_diff_files()) <= self.COLLAPSIBLE_FILE_LIST_THRESHOLD,
'duplicate_prompt_examples': get_settings().config.get('duplicate_prompt_examples', False),
"include_file_summary_changes": len(self.git_provider.get_diff_files())
<= self.COLLAPSIBLE_FILE_LIST_THRESHOLD,
'duplicate_prompt_examples': get_settings().config.get(
'duplicate_prompt_examples', False
),
}
self.user_description = self.git_provider.get_user_description()
@ -91,10 +118,14 @@ class PRDescription:
async def run(self):
try:
get_logger().info(f"Generating a PR description for pr_id: {self.pr_id}")
relevant_configs = {'pr_description': dict(get_settings().pr_description),
'config': dict(get_settings().config)}
relevant_configs = {
'pr_description': dict(get_settings().pr_description),
'config': dict(get_settings().config),
}
get_logger().debug("Relevant configs", artifacts=relevant_configs)
if get_settings().config.publish_output and not get_settings().config.get('is_auto_command', False):
if get_settings().config.publish_output and not get_settings().config.get(
'is_auto_command', False
):
self.git_provider.publish_comment("准备 PR 描述中...", is_temporary=True)
# ticket extraction if exists
@ -119,40 +150,73 @@ class PRDescription:
get_logger().debug(f"Publishing labels disabled")
if get_settings().pr_description.use_description_markers:
pr_title, pr_body, changes_walkthrough, pr_file_changes = self._prepare_pr_answer_with_markers()
(
pr_title,
pr_body,
changes_walkthrough,
pr_file_changes,
) = self._prepare_pr_answer_with_markers()
else:
pr_title, pr_body, changes_walkthrough, pr_file_changes = self._prepare_pr_answer()
if not self.git_provider.is_supported(
"publish_file_comments") or not get_settings().pr_description.inline_file_summary:
(
pr_title,
pr_body,
changes_walkthrough,
pr_file_changes,
) = self._prepare_pr_answer()
if (
not self.git_provider.is_supported("publish_file_comments")
or not get_settings().pr_description.inline_file_summary
):
pr_body += "\n\n" + changes_walkthrough
get_logger().debug("PR output", artifact={"title": pr_title, "body": pr_body})
get_logger().debug(
"PR output", artifact={"title": pr_title, "body": pr_body}
)
# Add help text if gfm_markdown is supported
if self.git_provider.is_supported("gfm_markdown") and get_settings().pr_description.enable_help_text:
if (
self.git_provider.is_supported("gfm_markdown")
and get_settings().pr_description.enable_help_text
):
pr_body += "<hr>\n\n<details> <summary><strong>✨ 工具使用指南:</strong></summary><hr> \n\n"
pr_body += HelpMessage.get_describe_usage_guide()
pr_body += "\n</details>\n"
elif get_settings().pr_description.enable_help_comment and self.git_provider.is_supported("gfm_markdown"):
elif (
get_settings().pr_description.enable_help_comment
and self.git_provider.is_supported("gfm_markdown")
):
if isinstance(self.git_provider, GithubProvider):
pr_body += ('\n\n___\n\n> <details> <summary> 需要帮助?</summary><li>Type <code>/help 如何 ...</code> '
'关于PR-Agent使用的任何问题,请在评论区留言.</li><li>查看一下 '
'<a href="https://qodo-merge-docs.qodo.ai/usage-guide/">documentation</a> '
'了解更多.</li></details>')
else: # gitlab
pr_body += ("\n\n___\n\n<details><summary>需要帮助?</summary>- Type <code>/help 如何 ...</code> 在评论中 "
"关于PR-Agent使用的任何问题请在此发帖. <br>- 查看一下 "
"<a href='https://qodo-merge-docs.qodo.ai/usage-guide/'>documentation</a> 了解更多.</details>")
pr_body += (
'\n\n___\n\n> <details> <summary> 需要帮助?</summary><li>Type <code>/help 如何 ...</code> '
'关于PR-Agent使用的任何问题,请在评论区留言.</li><li>查看一下 '
'<a href="https://qodo-merge-docs.qodo.ai/usage-guide/">documentation</a> '
'了解更多.</li></details>'
)
else: # gitlab
pr_body += (
"\n\n___\n\n<details><summary>需要帮助?</summary>- Type <code>/help 如何 ...</code> 在评论中 "
"关于PR-Agent使用的任何问题请在此发帖. <br>- 查看一下 "
"<a href='https://qodo-merge-docs.qodo.ai/usage-guide/'>documentation</a> 了解更多.</details>"
)
# elif get_settings().pr_description.enable_help_comment:
# pr_body += '\n\n___\n\n> 💡 **PR-Agent usage**: Comment `/help "your question"` on any pull request to receive relevant information'
# Output the relevant configurations if enabled
if get_settings().get('config', {}).get('output_relevant_configurations', False):
pr_body += show_relevant_configurations(relevant_section='pr_description')
if (
get_settings()
.get('config', {})
.get('output_relevant_configurations', False)
):
pr_body += show_relevant_configurations(
relevant_section='pr_description'
)
if get_settings().config.publish_output:
# publish labels
if get_settings().pr_description.publish_labels and pr_labels and self.git_provider.is_supported("get_labels"):
if (
get_settings().pr_description.publish_labels
and pr_labels
and self.git_provider.is_supported("get_labels")
):
original_labels = self.git_provider.get_pr_labels(update=True)
get_logger().debug(f"original labels", artifact=original_labels)
user_labels = get_user_labels(original_labels)
@ -165,20 +229,29 @@ class PRDescription:
# publish description
if get_settings().pr_description.publish_description_as_comment:
full_markdown_description = f"## Title\n\n{pr_title}\n\n___\n{pr_body}"
if get_settings().pr_description.publish_description_as_comment_persistent:
self.git_provider.publish_persistent_comment(full_markdown_description,
initial_header="## Title",
update_header=True,
name="describe",
final_update_message=False, )
full_markdown_description = (
f"## Title\n\n{pr_title}\n\n___\n{pr_body}"
)
if (
get_settings().pr_description.publish_description_as_comment_persistent
):
self.git_provider.publish_persistent_comment(
full_markdown_description,
initial_header="## Title",
update_header=True,
name="describe",
final_update_message=False,
)
else:
self.git_provider.publish_comment(full_markdown_description)
else:
self.git_provider.publish_description(pr_title, pr_body)
# publish final update message
if (get_settings().pr_description.final_update_message and not get_settings().config.get('is_auto_command', False)):
if (
get_settings().pr_description.final_update_message
and not get_settings().config.get('is_auto_command', False)
):
latest_commit_url = self.git_provider.get_latest_commit_url()
if latest_commit_url:
pr_url = self.git_provider.get_pr_url()
@ -186,22 +259,40 @@ class PRDescription:
self.git_provider.publish_comment(update_comment)
self.git_provider.remove_initial_comment()
else:
get_logger().info('PR description, but not published since publish_output is False.')
get_logger().info(
'PR description, but not published since publish_output is False.'
)
get_settings().data = {"artifact": pr_body}
return
except Exception as e:
get_logger().error(f"Error generating PR description {self.pr_id}: {e}",
artifact={"traceback": traceback.format_exc()})
get_logger().error(
f"Error generating PR description {self.pr_id}: {e}",
artifact={"traceback": traceback.format_exc()},
)
return ""
async def _prepare_prediction(self, model: str) -> None:
if get_settings().pr_description.use_description_markers and 'pr_agent:' not in self.user_description:
get_logger().info("Markers were enabled, but user description does not contain markers. skipping AI prediction")
if (
get_settings().pr_description.use_description_markers
and 'pr_agent:' not in self.user_description
):
get_logger().info(
"Markers were enabled, but user description does not contain markers. skipping AI prediction"
)
return None
large_pr_handling = get_settings().pr_description.enable_large_pr_handling and "pr_description_only_files_prompts" in get_settings()
output = get_pr_diff(self.git_provider, self.token_handler, model, large_pr_handling=large_pr_handling, return_remaining_files=True)
large_pr_handling = (
get_settings().pr_description.enable_large_pr_handling
and "pr_description_only_files_prompts" in get_settings()
)
output = get_pr_diff(
self.git_provider,
self.token_handler,
model,
large_pr_handling=large_pr_handling,
return_remaining_files=True,
)
if isinstance(output, tuple):
patches_diff, remaining_files_list = output
else:
@ -213,14 +304,18 @@ class PRDescription:
if patches_diff:
# generate the prediction
get_logger().debug(f"PR diff", artifact=self.patches_diff)
self.prediction = await self._get_prediction(model, patches_diff, prompt="pr_description_prompt")
self.prediction = await self._get_prediction(
model, patches_diff, prompt="pr_description_prompt"
)
# extend the prediction with additional files not shown
if get_settings().pr_description.enable_semantic_files_types:
self.prediction = await self.extend_uncovered_files(self.prediction)
else:
get_logger().error(f"Error getting PR diff {self.pr_id}",
artifact={"traceback": traceback.format_exc()})
get_logger().error(
f"Error getting PR diff {self.pr_id}",
artifact={"traceback": traceback.format_exc()},
)
self.prediction = None
else:
# get the diff in multiple patches, with the token handler only for the files prompt
@ -231,9 +326,16 @@ class PRDescription:
get_settings().pr_description_only_files_prompts.system,
get_settings().pr_description_only_files_prompts.user,
)
(patches_compressed_list, total_tokens_list, deleted_files_list, remaining_files_list, file_dict,
files_in_patches_list) = get_pr_diff_multiple_patchs(
self.git_provider, token_handler_only_files_prompt, model)
(
patches_compressed_list,
total_tokens_list,
deleted_files_list,
remaining_files_list,
file_dict,
files_in_patches_list,
) = get_pr_diff_multiple_patchs(
self.git_provider, token_handler_only_files_prompt, model
)
# get the files prediction for each patch
if not get_settings().pr_description.async_ai_calls:
@ -241,8 +343,9 @@ class PRDescription:
for i, patches in enumerate(patches_compressed_list): # sync calls
patches_diff = "\n".join(patches)
get_logger().debug(f"PR diff number {i + 1} for describe files")
prediction_files = await self._get_prediction(model, patches_diff,
prompt="pr_description_only_files_prompts")
prediction_files = await self._get_prediction(
model, patches_diff, prompt="pr_description_only_files_prompts"
)
results.append(prediction_files)
else: # async calls
tasks = []
@ -251,34 +354,52 @@ class PRDescription:
patches_diff = "\n".join(patches)
get_logger().debug(f"PR diff number {i + 1} for describe files")
task = asyncio.create_task(
self._get_prediction(model, patches_diff, prompt="pr_description_only_files_prompts"))
self._get_prediction(
model,
patches_diff,
prompt="pr_description_only_files_prompts",
)
)
tasks.append(task)
# Wait for all tasks to complete
results = await asyncio.gather(*tasks)
file_description_str_list = []
for i, result in enumerate(results):
prediction_files = result.strip().removeprefix('```yaml').strip('`').strip()
if load_yaml(prediction_files, keys_fix_yaml=self.keys_fix) and prediction_files.startswith('pr_files'):
prediction_files = prediction_files.removeprefix('pr_files:').strip()
prediction_files = (
result.strip().removeprefix('```yaml').strip('`').strip()
)
if load_yaml(
prediction_files, keys_fix_yaml=self.keys_fix
) and prediction_files.startswith('pr_files'):
prediction_files = prediction_files.removeprefix(
'pr_files:'
).strip()
file_description_str_list.append(prediction_files)
else:
get_logger().debug(f"failed to generate predictions in iteration {i + 1} for describe files")
get_logger().debug(
f"failed to generate predictions in iteration {i + 1} for describe files"
)
# generate files_walkthrough string, with proper token handling
token_handler_only_description_prompt = TokenHandler(
self.git_provider.pr,
self.vars,
get_settings().pr_description_only_description_prompts.system,
get_settings().pr_description_only_description_prompts.user)
get_settings().pr_description_only_description_prompts.user,
)
files_walkthrough = "\n".join(file_description_str_list)
files_walkthrough_prompt = copy.deepcopy(files_walkthrough)
MAX_EXTRA_FILES_TO_PROMPT = 50
if remaining_files_list:
files_walkthrough_prompt += "\n\nNo more token budget. Additional unprocessed files:"
files_walkthrough_prompt += (
"\n\nNo more token budget. Additional unprocessed files:"
)
for i, file in enumerate(remaining_files_list):
files_walkthrough_prompt += f"\n- {file}"
if i >= MAX_EXTRA_FILES_TO_PROMPT:
get_logger().debug(f"Too many remaining files, clipping to {MAX_EXTRA_FILES_TO_PROMPT}")
get_logger().debug(
f"Too many remaining files, clipping to {MAX_EXTRA_FILES_TO_PROMPT}"
)
files_walkthrough_prompt += f"\n... and {len(remaining_files_list) - MAX_EXTRA_FILES_TO_PROMPT} more"
break
if deleted_files_list:
@ -286,32 +407,57 @@ class PRDescription:
for i, file in enumerate(deleted_files_list):
files_walkthrough_prompt += f"\n- {file}"
if i >= MAX_EXTRA_FILES_TO_PROMPT:
get_logger().debug(f"Too many deleted files, clipping to {MAX_EXTRA_FILES_TO_PROMPT}")
get_logger().debug(
f"Too many deleted files, clipping to {MAX_EXTRA_FILES_TO_PROMPT}"
)
files_walkthrough_prompt += f"\n... and {len(deleted_files_list) - MAX_EXTRA_FILES_TO_PROMPT} more"
break
tokens_files_walkthrough = len(
token_handler_only_description_prompt.encoder.encode(files_walkthrough_prompt))
total_tokens = token_handler_only_description_prompt.prompt_tokens + tokens_files_walkthrough
token_handler_only_description_prompt.encoder.encode(
files_walkthrough_prompt
)
)
total_tokens = (
token_handler_only_description_prompt.prompt_tokens
+ tokens_files_walkthrough
)
max_tokens_model = get_max_tokens(model)
if total_tokens > max_tokens_model - OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD:
# clip files_walkthrough to fit the tokens within the limit
files_walkthrough_prompt = clip_tokens(files_walkthrough_prompt,
max_tokens_model - OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD - token_handler_only_description_prompt.prompt_tokens,
num_input_tokens=tokens_files_walkthrough)
files_walkthrough_prompt = clip_tokens(
files_walkthrough_prompt,
max_tokens_model
- OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD
- token_handler_only_description_prompt.prompt_tokens,
num_input_tokens=tokens_files_walkthrough,
)
# PR header inference
get_logger().debug(f"PR diff only description", artifact=files_walkthrough_prompt)
prediction_headers = await self._get_prediction(model, patches_diff=files_walkthrough_prompt,
prompt="pr_description_only_description_prompts")
prediction_headers = prediction_headers.strip().removeprefix('```yaml').strip('`').strip()
get_logger().debug(
f"PR diff only description", artifact=files_walkthrough_prompt
)
prediction_headers = await self._get_prediction(
model,
patches_diff=files_walkthrough_prompt,
prompt="pr_description_only_description_prompts",
)
prediction_headers = (
prediction_headers.strip().removeprefix('```yaml').strip('`').strip()
)
# extend the tables with the files not shown
files_walkthrough_extended = await self.extend_uncovered_files(files_walkthrough)
files_walkthrough_extended = await self.extend_uncovered_files(
files_walkthrough
)
# final processing
self.prediction = prediction_headers + "\n" + "pr_files:\n" + files_walkthrough_extended
self.prediction = (
prediction_headers + "\n" + "pr_files:\n" + files_walkthrough_extended
)
if not load_yaml(self.prediction, keys_fix_yaml=self.keys_fix):
get_logger().error(f"Error getting valid YAML in large PR handling for describe {self.pr_id}")
get_logger().error(
f"Error getting valid YAML in large PR handling for describe {self.pr_id}"
)
if load_yaml(prediction_headers, keys_fix_yaml=self.keys_fix):
get_logger().debug(f"Using only headers for describe {self.pr_id}")
self.prediction = prediction_headers
@ -321,12 +467,17 @@ class PRDescription:
prediction = original_prediction
# get the original prediction filenames
original_prediction_loaded = load_yaml(original_prediction, keys_fix_yaml=self.keys_fix)
original_prediction_loaded = load_yaml(
original_prediction, keys_fix_yaml=self.keys_fix
)
if isinstance(original_prediction_loaded, list):
original_prediction_dict = {"pr_files": original_prediction_loaded}
else:
original_prediction_dict = original_prediction_loaded
filenames_predicted = [file['filename'].strip() for file in original_prediction_dict.get('pr_files', [])]
filenames_predicted = [
file['filename'].strip()
for file in original_prediction_dict.get('pr_files', [])
]
# extend the prediction with additional files not included in the original prediction
pr_files = self.git_provider.get_diff_files()
@ -349,7 +500,9 @@ class PRDescription:
additional files
"""
prediction_extra = prediction_extra + "\n" + extra_file_yaml.strip()
get_logger().debug(f"Too many remaining files, clipping to {MAX_EXTRA_FILES_TO_OUTPUT}")
get_logger().debug(
f"Too many remaining files, clipping to {MAX_EXTRA_FILES_TO_OUTPUT}"
)
break
extra_file_yaml = f"""\
@ -364,10 +517,18 @@ class PRDescription:
# merge the two dictionaries
if counter_extra_files > 0:
get_logger().info(f"Adding {counter_extra_files} unprocessed extra files to table prediction")
prediction_extra_dict = load_yaml(prediction_extra, keys_fix_yaml=self.keys_fix)
if isinstance(original_prediction_dict, dict) and isinstance(prediction_extra_dict, dict):
original_prediction_dict["pr_files"].extend(prediction_extra_dict["pr_files"])
get_logger().info(
f"Adding {counter_extra_files} unprocessed extra files to table prediction"
)
prediction_extra_dict = load_yaml(
prediction_extra, keys_fix_yaml=self.keys_fix
)
if isinstance(original_prediction_dict, dict) and isinstance(
prediction_extra_dict, dict
):
original_prediction_dict["pr_files"].extend(
prediction_extra_dict["pr_files"]
)
new_yaml = yaml.dump(original_prediction_dict)
if load_yaml(new_yaml, keys_fix_yaml=self.keys_fix):
prediction = new_yaml
@ -379,11 +540,12 @@ class PRDescription:
get_logger().error(f"Error extending uncovered files {self.pr_id}: {e}")
return original_prediction
async def extend_additional_files(self, remaining_files_list) -> str:
prediction = self.prediction
try:
original_prediction_dict = load_yaml(self.prediction, keys_fix_yaml=self.keys_fix)
original_prediction_dict = load_yaml(
self.prediction, keys_fix_yaml=self.keys_fix
)
prediction_extra = "pr_files:"
for file in remaining_files_list:
extra_file_yaml = f"""\
@ -397,10 +559,16 @@ class PRDescription:
additional files (token-limit)
"""
prediction_extra = prediction_extra + "\n" + extra_file_yaml.strip()
prediction_extra_dict = load_yaml(prediction_extra, keys_fix_yaml=self.keys_fix)
prediction_extra_dict = load_yaml(
prediction_extra, keys_fix_yaml=self.keys_fix
)
# merge the two dictionaries
if isinstance(original_prediction_dict, dict) and isinstance(prediction_extra_dict, dict):
original_prediction_dict["pr_files"].extend(prediction_extra_dict["pr_files"])
if isinstance(original_prediction_dict, dict) and isinstance(
prediction_extra_dict, dict
):
original_prediction_dict["pr_files"].extend(
prediction_extra_dict["pr_files"]
)
new_yaml = yaml.dump(original_prediction_dict)
if load_yaml(new_yaml, keys_fix_yaml=self.keys_fix):
prediction = new_yaml
@ -409,7 +577,9 @@ class PRDescription:
get_logger().error(f"Error extending additional files {self.pr_id}: {e}")
return self.prediction
async def _get_prediction(self, model: str, patches_diff: str, prompt="pr_description_prompt") -> str:
async def _get_prediction(
self, model: str, patches_diff: str, prompt="pr_description_prompt"
) -> str:
variables = copy.deepcopy(self.vars)
variables["diff"] = patches_diff # update diff
@ -417,14 +587,18 @@ class PRDescription:
set_custom_labels(variables, self.git_provider)
self.variables = variables
system_prompt = environment.from_string(get_settings().get(prompt, {}).get("system", "")).render(self.variables)
user_prompt = environment.from_string(get_settings().get(prompt, {}).get("user", "")).render(self.variables)
system_prompt = environment.from_string(
get_settings().get(prompt, {}).get("system", "")
).render(self.variables)
user_prompt = environment.from_string(
get_settings().get(prompt, {}).get("user", "")
).render(self.variables)
response, finish_reason = await self.ai_handler.chat_completion(
model=model,
temperature=get_settings().config.temperature,
system=system_prompt,
user=user_prompt
user=user_prompt,
)
return response
@ -433,7 +607,10 @@ class PRDescription:
# Load the AI prediction data into a dictionary
self.data = load_yaml(self.prediction.strip(), keys_fix_yaml=self.keys_fix)
if get_settings().pr_description.add_original_user_description and self.user_description:
if (
get_settings().pr_description.add_original_user_description
and self.user_description
):
self.data["User Description"] = self.user_description
# re-order keys
@ -459,7 +636,11 @@ class PRDescription:
pr_labels = self.data['labels']
elif type(self.data['labels']) == str:
pr_labels = self.data['labels'].split(',')
elif 'type' in self.data and self.data['type'] and get_settings().pr_description.publish_labels:
elif (
'type' in self.data
and self.data['type']
and get_settings().pr_description.publish_labels
):
if type(self.data['type']) == list:
pr_labels = self.data['type']
elif type(self.data['type']) == str:
@ -474,7 +655,9 @@ class PRDescription:
if label_i in d:
pr_labels[i] = d[label_i]
except Exception as e:
get_logger().error(f"Error converting labels to original case {self.pr_id}: {e}")
get_logger().error(
f"Error converting labels to original case {self.pr_id}: {e}"
)
return pr_labels
def _prepare_pr_answer_with_markers(self) -> Tuple[str, str, str, List[dict]]:
@ -482,13 +665,13 @@ class PRDescription:
# Remove the 'PR Title' key from the dictionary
ai_title = self.data.pop('title', self.vars["title"])
if (not get_settings().pr_description.generate_ai_title):
if not get_settings().pr_description.generate_ai_title:
# Assign the original PR title to the 'title' variable
title = self.vars["title"]
else:
# Assign the value of the 'PR Title' key to 'title' variable
title = ai_title
body = self.user_description
if get_settings().pr_description.include_generated_by_header:
ai_header = f"### 🤖 Generated by PR Agent at {self.git_provider.last_commit_id.sha}\n\n"
@ -514,8 +697,9 @@ class PRDescription:
pr_file_changes = []
if ai_walkthrough and not re.search(r'<!--\s*pr_agent:walkthrough\s*-->', body):
try:
walkthrough_gfm, pr_file_changes = self.process_pr_files_prediction(walkthrough_gfm,
self.file_label_dict)
walkthrough_gfm, pr_file_changes = self.process_pr_files_prediction(
walkthrough_gfm, self.file_label_dict
)
body = body.replace('pr_agent:walkthrough', walkthrough_gfm)
except Exception as e:
get_logger().error(f"Failing to process walkthrough {self.pr_id}: {e}")
@ -545,7 +729,7 @@ class PRDescription:
# Remove the 'PR Title' key from the dictionary
ai_title = self.data.pop('title', self.vars["title"])
if (not get_settings().pr_description.generate_ai_title):
if not get_settings().pr_description.generate_ai_title:
# Assign the original PR title to the 'title' variable
title = self.vars["title"]
else:
@ -575,13 +759,20 @@ class PRDescription:
pr_body += f'- `{filename}`: {description}\n'
if self.git_provider.is_supported("gfm_markdown"):
pr_body += "</details>\n"
elif 'pr_files' in key.lower() and get_settings().pr_description.enable_semantic_files_types:
changes_walkthrough, pr_file_changes = self.process_pr_files_prediction(changes_walkthrough, value)
elif (
'pr_files' in key.lower()
and get_settings().pr_description.enable_semantic_files_types
):
changes_walkthrough, pr_file_changes = self.process_pr_files_prediction(
changes_walkthrough, value
)
changes_walkthrough = f"{PRDescriptionHeader.CHANGES_WALKTHROUGH.value}\n{changes_walkthrough}"
elif key.lower().strip() == 'description':
if isinstance(value, list):
value = ', '.join(v.rstrip() for v in value)
value = value.replace('\n-', '\n\n-').strip() # makes the bullet points more readable by adding double space
value = value.replace(
'\n-', '\n\n-'
).strip() # makes the bullet points more readable by adding double space
pr_body += f"{value}\n"
else:
# if the value is a list, join its items by comma
@ -591,24 +782,37 @@ class PRDescription:
if idx < len(self.data) - 1:
pr_body += "\n\n___\n\n"
return title, pr_body, changes_walkthrough, pr_file_changes,
return (
title,
pr_body,
changes_walkthrough,
pr_file_changes,
)
def _prepare_file_labels(self):
file_label_dict = {}
if (not self.data or not isinstance(self.data, dict) or
'pr_files' not in self.data or not self.data['pr_files']):
if (
not self.data
or not isinstance(self.data, dict)
or 'pr_files' not in self.data
or not self.data['pr_files']
):
return file_label_dict
for file in self.data['pr_files']:
try:
required_fields = ['changes_title', 'filename', 'label']
if not all(field in file for field in required_fields):
# can happen for example if a YAML generation was interrupted in the middle (no more tokens)
get_logger().warning(f"Missing required fields in file label dict {self.pr_id}, skipping file",
artifact={"file": file})
get_logger().warning(
f"Missing required fields in file label dict {self.pr_id}, skipping file",
artifact={"file": file},
)
continue
if not file.get('changes_title'):
get_logger().warning(f"Empty changes title or summary in file label dict {self.pr_id}, skipping file",
artifact={"file": file})
get_logger().warning(
f"Empty changes title or summary in file label dict {self.pr_id}, skipping file",
artifact={"file": file},
)
continue
filename = file['filename'].replace("'", "`").replace('"', '`')
changes_summary = file.get('changes_summary', "").strip()
@ -616,7 +820,9 @@ class PRDescription:
label = file.get('label').strip().lower()
if label not in file_label_dict:
file_label_dict[label] = []
file_label_dict[label].append((filename, changes_title, changes_summary))
file_label_dict[label].append(
(filename, changes_title, changes_summary)
)
except Exception as e:
get_logger().error(f"Error preparing file label dict {self.pr_id}: {e}")
pass
@ -640,7 +846,9 @@ class PRDescription:
header = f"相关文件"
delta = 75
# header += "&nbsp; " * delta
pr_body += f"""<thead><tr><th></th><th align="left">{header}</th></tr></thead>"""
pr_body += (
f"""<thead><tr><th></th><th align="left">{header}</th></tr></thead>"""
)
pr_body += """<tbody>"""
for semantic_label in value.keys():
s_label = semantic_label.strip("'").strip('"')
@ -651,14 +859,22 @@ class PRDescription:
pr_body += f"""<td><details><summary>{len(list_tuples)} files</summary><table>"""
else:
pr_body += f"""<td><table>"""
for filename, file_changes_title, file_change_description in list_tuples:
for (
filename,
file_changes_title,
file_change_description,
) in list_tuples:
filename = filename.replace("'", "`").rstrip()
filename_publish = filename.split("/")[-1]
if file_changes_title and file_changes_title.strip() != "...":
file_changes_title_code = f"<code>{file_changes_title}</code>"
file_changes_title_code_br = insert_br_after_x_chars(file_changes_title_code, x=(delta - 5)).strip()
file_changes_title_code_br = insert_br_after_x_chars(
file_changes_title_code, x=(delta - 5)
).strip()
if len(file_changes_title_code_br) < (delta - 5):
file_changes_title_code_br += "&nbsp; " * ((delta - 5) - len(file_changes_title_code_br))
file_changes_title_code_br += "&nbsp; " * (
(delta - 5) - len(file_changes_title_code_br)
)
filename_publish = f"<strong>{filename_publish}</strong><dd>{file_changes_title_code_br}</dd>"
else:
filename_publish = f"<strong>{filename_publish}</strong>"
@ -679,15 +895,30 @@ class PRDescription:
link = ""
if hasattr(self.git_provider, 'get_line_link'):
filename = filename.strip()
link = self.git_provider.get_line_link(filename, relevant_line_start=-1)
if (not link or not diff_plus_minus) and ('additional files' not in filename.lower()):
get_logger().warning(f"Error getting line link for '{filename}'")
link = self.git_provider.get_line_link(
filename, relevant_line_start=-1
)
if (not link or not diff_plus_minus) and (
'additional files' not in filename.lower()
):
get_logger().warning(
f"Error getting line link for '{filename}'"
)
continue
# Add file data to the PR body
file_change_description_br = insert_br_after_x_chars(file_change_description, x=(delta - 5))
pr_body = self.add_file_data(delta_nbsp, diff_plus_minus, file_change_description_br, filename,
filename_publish, link, pr_body)
file_change_description_br = insert_br_after_x_chars(
file_change_description, x=(delta - 5)
)
pr_body = self.add_file_data(
delta_nbsp,
diff_plus_minus,
file_change_description_br,
filename,
filename_publish,
link,
pr_body,
)
# Close the collapsible file list
if use_collapsible_file_list:
@ -697,13 +928,22 @@ class PRDescription:
pr_body += """</tr></tbody></table>"""
except Exception as e:
get_logger().error(f"Error processing pr files to markdown {self.pr_id}: {str(e)}")
get_logger().error(
f"Error processing pr files to markdown {self.pr_id}: {str(e)}"
)
pass
return pr_body, pr_comments
def add_file_data(self, delta_nbsp, diff_plus_minus, file_change_description_br, filename, filename_publish, link,
pr_body) -> str:
def add_file_data(
self,
delta_nbsp,
diff_plus_minus,
file_change_description_br,
filename,
filename_publish,
link,
pr_body,
) -> str:
if not file_change_description_br:
pr_body += f"""
<tr>
@ -735,6 +975,7 @@ class PRDescription:
"""
return pr_body
def count_chars_without_html(string):
if '<' not in string:
return len(string)
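The extend_uncovered_files / extend_additional_files logic above builds placeholder YAML rows for files that never fit into the prompt and merges them back into the model's prediction, keeping the merge only if it still parses. A minimal sketch of that merge step, assuming plain PyYAML in place of the project's load_yaml / keys_fix_yaml helpers; the helper name is made up and the field names simply mirror the table rows shown above:

import yaml

def merge_extra_files(original_prediction: str, remaining_files: list) -> str:
    # Parse the model's prediction; bail out if it is not the expected mapping.
    try:
        original = yaml.safe_load(original_prediction)
    except yaml.YAMLError:
        return original_prediction
    if not isinstance(original, dict) or not original.get("pr_files"):
        return original_prediction

    # Placeholder rows for files the model never saw.
    extra_entries = [
        {"filename": name, "changes_title": "...", "label": "additional files"}
        for name in remaining_files
    ]
    original["pr_files"].extend(extra_entries)

    # Mirror the original's guard: keep the merged text only if it still parses.
    merged = yaml.dump(original)
    return merged if yaml.safe_load(merged) else original_prediction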


@ -16,8 +16,12 @@ from utils.pr_agent.log import get_logger
class PRGenerateLabels:
def __init__(self, pr_url: str, args: list = None,
ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
def __init__(
self,
pr_url: str,
args: list = None,
ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler,
):
"""
Initialize the PRGenerateLabels object with the necessary attributes and objects for generating labels
corresponding to the PR using an AI model.
@ -93,7 +97,9 @@ class PRGenerateLabels:
elif pr_labels:
value = ', '.join(v for v in pr_labels)
pr_labels_text = f"## PR Labels:\n{value}\n"
self.git_provider.publish_comment(pr_labels_text, is_temporary=False)
self.git_provider.publish_comment(
pr_labels_text, is_temporary=False
)
self.git_provider.remove_initial_comment()
except Exception as e:
get_logger().error(f"Error generating PR labels {self.pr_id}: {e}")
@ -137,14 +143,18 @@ class PRGenerateLabels:
set_custom_labels(variables, self.git_provider)
self.variables = variables
system_prompt = environment.from_string(get_settings().pr_custom_labels_prompt.system).render(self.variables)
user_prompt = environment.from_string(get_settings().pr_custom_labels_prompt.user).render(self.variables)
system_prompt = environment.from_string(
get_settings().pr_custom_labels_prompt.system
).render(self.variables)
user_prompt = environment.from_string(
get_settings().pr_custom_labels_prompt.user
).render(self.variables)
response, finish_reason = await self.ai_handler.chat_completion(
model=model,
temperature=get_settings().config.temperature,
system=system_prompt,
user=user_prompt
user=user_prompt,
)
return response
@ -153,8 +163,6 @@ class PRGenerateLabels:
# Load the AI prediction data into a dictionary
self.data = load_yaml(self.prediction.strip())
def _prepare_labels(self) -> List[str]:
pr_types = []
@ -174,6 +182,8 @@ class PRGenerateLabels:
if label_i in d:
pr_types[i] = d[label_i]
except Exception as e:
get_logger().error(f"Error converting labels to original case {self.pr_id}: {e}")
get_logger().error(
f"Error converting labels to original case {self.pr_id}: {e}"
)
return pr_types
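Each tool in this commit renders its prompts the same way: Jinja2 templates taken from get_settings(), rendered with StrictUndefined so a missing variable raises instead of silently producing an empty string. A standalone sketch of that pattern with made-up templates and variables (the real templates live in the settings files):

from jinja2 import Environment, StrictUndefined

def render_prompts(system_tmpl: str, user_tmpl: str, variables: dict) -> tuple:
    # StrictUndefined turns a missing template variable into an UndefinedError
    # at render time, instead of rendering it as an empty string.
    environment = Environment(undefined=StrictUndefined)
    system_prompt = environment.from_string(system_tmpl).render(variables)
    user_prompt = environment.from_string(user_tmpl).render(variables)
    return system_prompt, user_prompt

# Hypothetical usage; the template text and variables are illustrative only.
system, user = render_prompts(
    "You label pull requests for the {{ branch }} branch.",
    "Title: {{ title }}\n\nDiff:\n{{ diff }}",
    {"branch": "main", "title": "Refactor prompt rendering", "diff": "..."},
)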


@ -12,7 +12,11 @@ from utils.pr_agent.algo.pr_processing import retry_with_fallback_models
from utils.pr_agent.algo.token_handler import TokenHandler
from utils.pr_agent.algo.utils import ModelType, clip_tokens, load_yaml, get_max_tokens
from utils.pr_agent.config_loader import get_settings
from utils.pr_agent.git_providers import BitbucketServerProvider, GithubProvider, get_git_provider_with_context
from utils.pr_agent.git_providers import (
BitbucketServerProvider,
GithubProvider,
get_git_provider_with_context,
)
from utils.pr_agent.log import get_logger
@ -29,31 +33,50 @@ def extract_header(snippet):
res = f"#{highest_header.lower().replace(' ', '-')}"
return res
class PRHelpMessage:
def __init__(self, pr_url: str, args=None, ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler, return_as_string=False):
def __init__(
self,
pr_url: str,
args=None,
ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler,
return_as_string=False,
):
self.git_provider = get_git_provider_with_context(pr_url)
self.ai_handler = ai_handler()
self.question_str = self.parse_args(args)
self.return_as_string = return_as_string
self.num_retrieved_snippets = get_settings().get('pr_help.num_retrieved_snippets', 5)
self.num_retrieved_snippets = get_settings().get(
'pr_help.num_retrieved_snippets', 5
)
if self.question_str:
self.vars = {
"question": self.question_str,
"snippets": "",
}
self.token_handler = TokenHandler(None,
self.vars,
get_settings().pr_help_prompts.system,
get_settings().pr_help_prompts.user)
self.token_handler = TokenHandler(
None,
self.vars,
get_settings().pr_help_prompts.system,
get_settings().pr_help_prompts.user,
)
async def _prepare_prediction(self, model: str):
try:
variables = copy.deepcopy(self.vars)
environment = Environment(undefined=StrictUndefined)
system_prompt = environment.from_string(get_settings().pr_help_prompts.system).render(variables)
user_prompt = environment.from_string(get_settings().pr_help_prompts.user).render(variables)
system_prompt = environment.from_string(
get_settings().pr_help_prompts.system
).render(variables)
user_prompt = environment.from_string(
get_settings().pr_help_prompts.user
).render(variables)
response, finish_reason = await self.ai_handler.chat_completion(
model=model, temperature=get_settings().config.temperature, system=system_prompt, user=user_prompt)
model=model,
temperature=get_settings().config.temperature,
system=system_prompt,
user=user_prompt,
)
return response
except Exception as e:
get_logger().error(f"Error while preparing prediction: {e}")
@ -81,7 +104,7 @@ class PRHelpMessage:
'.': '',
'?': '',
'!': '',
' ': '-'
' ': '-',
}
# Compile regex pattern for characters to remove
@ -90,37 +113,69 @@ class PRHelpMessage:
# Perform replacements in a single pass and convert to lowercase
return pattern.sub(lambda m: replacements[m.group()], cleaned).lower()
except Exception:
get_logger().exception(f"Error while formatting markdown header", artifacts={'header': header})
get_logger().exception(
f"Error while formatting markdown header", artifacts={'header': header}
)
return ""
async def run(self):
try:
if self.question_str:
get_logger().info(f'Answering a PR question about the PR {self.git_provider.pr_url} ')
get_logger().info(
f'Answering a PR question about the PR {self.git_provider.pr_url} '
)
if not get_settings().get('openai.key'):
if get_settings().config.publish_output:
self.git_provider.publish_comment(
"The `Help` tool chat feature requires an OpenAI API key for calculating embeddings")
"The `Help` tool chat feature requires an OpenAI API key for calculating embeddings"
)
else:
get_logger().error("The `Help` tool chat feature requires an OpenAI API key for calculating embeddings")
get_logger().error(
"The `Help` tool chat feature requires an OpenAI API key for calculating embeddings"
)
return
# current path
docs_path= Path(__file__).parent.parent.parent / 'docs' / 'docs'
docs_path = Path(__file__).parent.parent.parent / 'docs' / 'docs'
# get all the 'md' files inside docs_path and its subdirectories
md_files = list(docs_path.glob('**/*.md'))
folders_to_exclude = ['/finetuning_benchmark/']
files_to_exclude = {'EXAMPLE_BEST_PRACTICE.md', 'compression_strategy.md', '/docs/overview/index.md'}
md_files = [file for file in md_files if not any(folder in str(file) for folder in folders_to_exclude) and not any(file.name == file_to_exclude for file_to_exclude in files_to_exclude)]
files_to_exclude = {
'EXAMPLE_BEST_PRACTICE.md',
'compression_strategy.md',
'/docs/overview/index.md',
}
md_files = [
file
for file in md_files
if not any(folder in str(file) for folder in folders_to_exclude)
and not any(
file.name == file_to_exclude
for file_to_exclude in files_to_exclude
)
]
# sort the 'md_files' so that 'priority_files' will be at the top
priority_files_strings = ['/docs/index.md', '/usage-guide', 'tools/describe.md', 'tools/review.md',
'tools/improve.md', '/faq']
md_files_priority = [file for file in md_files if
any(priority_string in str(file) for priority_string in priority_files_strings)]
md_files_not_priority = [file for file in md_files if file not in md_files_priority]
priority_files_strings = [
'/docs/index.md',
'/usage-guide',
'tools/describe.md',
'tools/review.md',
'tools/improve.md',
'/faq',
]
md_files_priority = [
file
for file in md_files
if any(
priority_string in str(file)
for priority_string in priority_files_strings
)
]
md_files_not_priority = [
file for file in md_files if file not in md_files_priority
]
md_files = md_files_priority + md_files_not_priority
docs_prompt = ""
@ -132,24 +187,36 @@ class PRHelpMessage:
except Exception as e:
get_logger().error(f"Error while reading the file {file}: {e}")
token_count = self.token_handler.count_tokens(docs_prompt)
get_logger().debug(f"Token count of full documentation website: {token_count}")
get_logger().debug(
f"Token count of full documentation website: {token_count}"
)
model = get_settings().config.model
if model in MAX_TOKENS:
max_tokens_full = MAX_TOKENS[model] # note - here we take the actual max tokens, without any reductions. we do aim to get the full documentation website in the prompt
max_tokens_full = MAX_TOKENS[
model
] # note - here we take the actual max tokens, without any reductions. we do aim to get the full documentation website in the prompt
else:
max_tokens_full = get_max_tokens(model)
delta_output = 2000
if token_count > max_tokens_full - delta_output:
get_logger().info(f"Token count {token_count} exceeds the limit {max_tokens_full - delta_output}. Skipping the PR Help message.")
docs_prompt = clip_tokens(docs_prompt, max_tokens_full - delta_output)
get_logger().info(
f"Token count {token_count} exceeds the limit {max_tokens_full - delta_output}. Skipping the PR Help message."
)
docs_prompt = clip_tokens(
docs_prompt, max_tokens_full - delta_output
)
self.vars['snippets'] = docs_prompt.strip()
# run the AI model
response = await retry_with_fallback_models(self._prepare_prediction, model_type=ModelType.REGULAR)
response = await retry_with_fallback_models(
self._prepare_prediction, model_type=ModelType.REGULAR
)
response_yaml = load_yaml(response)
if isinstance(response_yaml, str):
get_logger().warning(f"failing to parse response: {response_yaml}, publishing the response as is")
get_logger().warning(
f"failing to parse response: {response_yaml}, publishing the response as is"
)
if get_settings().config.publish_output:
answer_str = f"### Question: \n{self.question_str}\n\n"
answer_str += f"### Answer:\n\n"
@ -160,7 +227,9 @@ class PRHelpMessage:
relevant_sections = response_yaml.get('relevant_sections')
if not relevant_sections:
get_logger().info(f"Could not find relevant answer for the question: {self.question_str}")
get_logger().info(
f"Could not find relevant answer for the question: {self.question_str}"
)
if get_settings().config.publish_output:
answer_str = f"### Question: \n{self.question_str}\n\n"
answer_str += f"### Answer:\n\n"
@ -178,29 +247,38 @@ class PRHelpMessage:
for section in relevant_sections:
file = section.get('file_name').strip().removesuffix('.md')
if str(section['relevant_section_header_string']).strip():
markdown_header = self.format_markdown_header(section['relevant_section_header_string'])
markdown_header = self.format_markdown_header(
section['relevant_section_header_string']
)
answer_str += f"> - {base_path}{file}#{markdown_header}\n"
else:
answer_str += f"> - {base_path}{file}\n"
# publish the answer
if get_settings().config.publish_output:
self.git_provider.publish_comment(answer_str)
else:
get_logger().info(f"Answer:\n{answer_str}")
else:
if not isinstance(self.git_provider, BitbucketServerProvider) and not self.git_provider.is_supported("gfm_markdown"):
if not isinstance(
self.git_provider, BitbucketServerProvider
) and not self.git_provider.is_supported("gfm_markdown"):
self.git_provider.publish_comment(
"The `Help` tool requires gfm markdown, which is not supported by your code platform.")
"The `Help` tool requires gfm markdown, which is not supported by your code platform."
)
return
get_logger().info('Getting PR Help Message...')
relevant_configs = {'pr_help': dict(get_settings().pr_help),
'config': dict(get_settings().config)}
relevant_configs = {
'pr_help': dict(get_settings().pr_help),
'config': dict(get_settings().config),
}
get_logger().debug("Relevant configs", artifacts=relevant_configs)
pr_comment = "## PR Agent Walkthrough 🤖\n\n"
pr_comment += "Welcome to the PR Agent, an AI-powered tool for automated pull request analysis, feedback, suggestions and more."""
pr_comment += (
"Welcome to the PR Agent, an AI-powered tool for automated pull request analysis, feedback, suggestions and more."
""
)
pr_comment += "\n\nHere is a list of tools you can use to interact with the PR Agent:\n"
base_path = "https://pr-agent-docs.codium.ai/tools"
@ -211,32 +289,58 @@ class PRHelpMessage:
tool_names.append(f"[UPDATE CHANGELOG]({base_path}/update_changelog/)")
tool_names.append(f"[ADD DOCS]({base_path}/documentation/) 💎")
tool_names.append(f"[TEST]({base_path}/test/) 💎")
tool_names.append(f"[IMPROVE COMPONENT]({base_path}/improve_component/) 💎")
tool_names.append(
f"[IMPROVE COMPONENT]({base_path}/improve_component/) 💎"
)
tool_names.append(f"[ANALYZE]({base_path}/analyze/) 💎")
tool_names.append(f"[ASK]({base_path}/ask/)")
tool_names.append(f"[SIMILAR ISSUE]({base_path}/similar_issues/)")
tool_names.append(f"[GENERATE CUSTOM LABELS]({base_path}/custom_labels/) 💎")
tool_names.append(
f"[GENERATE CUSTOM LABELS]({base_path}/custom_labels/) 💎"
)
tool_names.append(f"[CI FEEDBACK]({base_path}/ci_feedback/) 💎")
tool_names.append(f"[CUSTOM PROMPT]({base_path}/custom_prompt/) 💎")
tool_names.append(f"[IMPLEMENT]({base_path}/implement/) 💎")
descriptions = []
descriptions.append("Generates PR description - title, type, summary, code walkthrough and labels")
descriptions.append("Adjustable feedback about the PR, possible issues, security concerns, review effort and more")
descriptions.append(
"Generates PR description - title, type, summary, code walkthrough and labels"
)
descriptions.append(
"Adjustable feedback about the PR, possible issues, security concerns, review effort and more"
)
descriptions.append("Code suggestions for improving the PR")
descriptions.append("Automatically updates the changelog")
descriptions.append("Generates documentation to methods/functions/classes that changed in the PR")
descriptions.append("Generates unit tests for a specific component, based on the PR code change")
descriptions.append("Code suggestions for a specific component that changed in the PR")
descriptions.append("Identifies code components that changed in the PR, and enables to interactively generate tests, docs, and code suggestions for each component")
descriptions.append(
"Generates documentation to methods/functions/classes that changed in the PR"
)
descriptions.append(
"Generates unit tests for a specific component, based on the PR code change"
)
descriptions.append(
"Code suggestions for a specific component that changed in the PR"
)
descriptions.append(
"Identifies code components that changed in the PR, and enables to interactively generate tests, docs, and code suggestions for each component"
)
descriptions.append("Answering free-text questions about the PR")
descriptions.append("Automatically retrieves and presents similar issues")
descriptions.append("Generates custom labels for the PR, based on specific guidelines defined by the user")
descriptions.append("Generates feedback and analysis for a failed CI job")
descriptions.append("Generates custom suggestions for improving the PR code, derived only from a specific guidelines prompt defined by the user")
descriptions.append("Generates implementation code from review suggestions")
descriptions.append(
"Automatically retrieves and presents similar issues"
)
descriptions.append(
"Generates custom labels for the PR, based on specific guidelines defined by the user"
)
descriptions.append(
"Generates feedback and analysis for a failed CI job"
)
descriptions.append(
"Generates custom suggestions for improving the PR code, derived only from a specific guidelines prompt defined by the user"
)
descriptions.append(
"Generates implementation code from review suggestions"
)
commands =[]
commands = []
commands.append("`/describe`")
commands.append("`/review`")
commands.append("`/improve`")
@ -271,7 +375,9 @@ class PRHelpMessage:
checkbox_list.append("[*]")
checkbox_list.append("[*]")
if isinstance(self.git_provider, GithubProvider) and not get_settings().config.get('disable_checkboxes', False):
if isinstance(
self.git_provider, GithubProvider
) and not get_settings().config.get('disable_checkboxes', False):
pr_comment += f"<table><tr align='left'><th align='left'>Tool</th><th align='left'>Description</th><th align='left'>Trigger Interactively :gem:</th></tr>"
for i in range(len(tool_names)):
pr_comment += f"\n<tr><td align='left'>\n\n<strong>{tool_names[i]}</strong></td>\n<td>{descriptions[i]}</td>\n<td>\n\n{checkbox_list[i]}\n</td></tr>"
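The documentation loader above counts the tokens of the concatenated docs and clips them to the model's context window minus a fixed 2000-token output buffer. A rough standalone sketch of that budget check, using tiktoken as a stand-in for the project's TokenHandler and clip_tokens helpers; the encoder name is an assumption and the truncation is cruder than the real helper:

import tiktoken

def clip_to_budget(docs_prompt: str, max_tokens_model: int, delta_output: int = 2000) -> str:
    # Reserve `delta_output` tokens for the model's answer and clip the docs to the rest.
    enc = tiktoken.get_encoding("cl100k_base")  # assumed encoder, not necessarily the project's choice
    tokens = enc.encode(docs_prompt)
    budget = max_tokens_model - delta_output
    if len(tokens) <= budget:
        return docs_prompt
    # Hard truncation on a token boundary; clip_tokens in the real code is more careful.
    return enc.decode(tokens[:budget])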


@ -5,8 +5,7 @@ from jinja2 import Environment, StrictUndefined
from utils.pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from utils.pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler
from utils.pr_agent.algo.git_patch_processing import (
extract_hunk_lines_from_patch)
from utils.pr_agent.algo.git_patch_processing import extract_hunk_lines_from_patch
from utils.pr_agent.algo.pr_processing import retry_with_fallback_models
from utils.pr_agent.algo.token_handler import TokenHandler
from utils.pr_agent.algo.utils import ModelType
@ -17,7 +16,12 @@ from utils.pr_agent.log import get_logger
class PR_LineQuestions:
def __init__(self, pr_url: str, args=None, ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
def __init__(
self,
pr_url: str,
args=None,
ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler,
):
self.question_str = self.parse_args(args)
self.git_provider = get_git_provider()(pr_url)
self.main_pr_language = get_main_pr_language(
@ -34,10 +38,12 @@ class PR_LineQuestions:
"full_hunk": "",
"selected_lines": "",
}
self.token_handler = TokenHandler(self.git_provider.pr,
self.vars,
get_settings().pr_line_questions_prompt.system,
get_settings().pr_line_questions_prompt.user)
self.token_handler = TokenHandler(
self.git_provider.pr,
self.vars,
get_settings().pr_line_questions_prompt.system,
get_settings().pr_line_questions_prompt.user,
)
self.patches_diff = None
self.prediction = None
@ -48,7 +54,6 @@ class PR_LineQuestions:
question_str = ""
return question_str
async def run(self):
get_logger().info('Answering a PR lines question...')
# if get_settings().config.publish_output:
@ -62,22 +67,27 @@ class PR_LineQuestions:
file_name = get_settings().get('file_name', '')
comment_id = get_settings().get('comment_id', '')
if ask_diff:
self.patch_with_lines, self.selected_lines = extract_hunk_lines_from_patch(ask_diff,
file_name,
line_start=line_start,
line_end=line_end,
side=side
)
self.patch_with_lines, self.selected_lines = extract_hunk_lines_from_patch(
ask_diff, file_name, line_start=line_start, line_end=line_end, side=side
)
else:
diff_files = self.git_provider.get_diff_files()
for file in diff_files:
if file.filename == file_name:
self.patch_with_lines, self.selected_lines = extract_hunk_lines_from_patch(file.patch, file.filename,
line_start=line_start,
line_end=line_end,
side=side)
(
self.patch_with_lines,
self.selected_lines,
) = extract_hunk_lines_from_patch(
file.patch,
file.filename,
line_start=line_start,
line_end=line_end,
side=side,
)
if self.patch_with_lines:
model_answer = await retry_with_fallback_models(self._get_prediction, model_type=ModelType.WEAK)
model_answer = await retry_with_fallback_models(
self._get_prediction, model_type=ModelType.WEAK
)
# sanitize the answer so that no line will start with "/"
model_answer_sanitized = model_answer.strip().replace("\n/", "\n /")
if model_answer_sanitized.startswith("/"):
@ -85,7 +95,9 @@ class PR_LineQuestions:
get_logger().info('Preparing answer...')
if comment_id:
self.git_provider.reply_to_comment_from_comment_id(comment_id, model_answer_sanitized)
self.git_provider.reply_to_comment_from_comment_id(
comment_id, model_answer_sanitized
)
else:
self.git_provider.publish_comment(model_answer_sanitized)
@ -96,8 +108,12 @@ class PR_LineQuestions:
variables["full_hunk"] = self.patch_with_lines # update diff
variables["selected_lines"] = self.selected_lines
environment = Environment(undefined=StrictUndefined)
system_prompt = environment.from_string(get_settings().pr_line_questions_prompt.system).render(variables)
user_prompt = environment.from_string(get_settings().pr_line_questions_prompt.user).render(variables)
system_prompt = environment.from_string(
get_settings().pr_line_questions_prompt.system
).render(variables)
user_prompt = environment.from_string(
get_settings().pr_line_questions_prompt.user
).render(variables)
if get_settings().config.verbosity_level >= 2:
# get_logger().info(f"\nSystem prompt:\n{system_prompt}")
# get_logger().info(f"\nUser prompt:\n{user_prompt}")
@ -105,5 +121,9 @@ class PR_LineQuestions:
print(f"\nUser prompt:\n{user_prompt}")
response, finish_reason = await self.ai_handler.chat_completion(
model=model, temperature=get_settings().config.temperature, system=system_prompt, user=user_prompt)
model=model,
temperature=get_settings().config.temperature,
system=system_prompt,
user=user_prompt,
)
return response
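The run() method above (and PRQuestions below) sanitizes the model's answer so that no line begins with "/", which the git platform's bot could otherwise re-parse as a new slash command. The same guard pulled out as a small helper, for illustration only:

def sanitize_answer(model_answer: str) -> str:
    # Prefix a space to any line that starts with "/" so the published comment
    # cannot be interpreted as another /command by the PR bot.
    sanitized = model_answer.strip().replace("\n/", "\n /")
    if sanitized.startswith("/"):
        sanitized = " " + sanitized
    return sanitized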


@ -16,7 +16,12 @@ from utils.pr_agent.servers.help import HelpMessage
class PRQuestions:
def __init__(self, pr_url: str, args=None, ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
def __init__(
self,
pr_url: str,
args=None,
ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler,
):
question_str = self.parse_args(args)
self.pr_url = pr_url
self.git_provider = get_git_provider()(pr_url)
@ -36,10 +41,12 @@ class PRQuestions:
"questions": self.question_str,
"commit_messages_str": self.git_provider.get_commit_messages(),
}
self.token_handler = TokenHandler(self.git_provider.pr,
self.vars,
get_settings().pr_questions_prompt.system,
get_settings().pr_questions_prompt.user)
self.token_handler = TokenHandler(
self.git_provider.pr,
self.vars,
get_settings().pr_questions_prompt.system,
get_settings().pr_questions_prompt.user,
)
self.patches_diff = None
self.prediction = None
@ -52,8 +59,10 @@ class PRQuestions:
async def run(self):
get_logger().info(f'Answering a PR question about the PR {self.pr_url} ')
relevant_configs = {'pr_questions': dict(get_settings().pr_questions),
'config': dict(get_settings().config)}
relevant_configs = {
'pr_questions': dict(get_settings().pr_questions),
'config': dict(get_settings().config),
}
get_logger().debug("Relevant configs", artifacts=relevant_configs)
if get_settings().config.publish_output:
self.git_provider.publish_comment("思考回答中...", is_temporary=True)
@ -63,12 +72,17 @@ class PRQuestions:
if img_path:
get_logger().debug(f"Image path identified", artifact=img_path)
await retry_with_fallback_models(self._prepare_prediction, model_type=ModelType.WEAK)
await retry_with_fallback_models(
self._prepare_prediction, model_type=ModelType.WEAK
)
pr_comment = self._prepare_pr_answer()
get_logger().debug(f"PR output", artifact=pr_comment)
if self.git_provider.is_supported("gfm_markdown") and get_settings().pr_questions.enable_help_text:
if (
self.git_provider.is_supported("gfm_markdown")
and get_settings().pr_questions.enable_help_text
):
pr_comment += "<hr>\n\n<details> <summary><strong>💡 Tool usage guide:</strong></summary><hr> \n\n"
pr_comment += HelpMessage.get_ask_usage_guide()
pr_comment += "\n</details>\n"
@ -85,7 +99,9 @@ class PRQuestions:
# /ask question ... > ![image](img_path)
img_path = self.question_str.split('![image]')[1].strip().strip('()')
self.vars['img_path'] = img_path
elif 'https://' in self.question_str and ('.png' in self.question_str or 'jpg' in self.question_str): # direct image link
elif 'https://' in self.question_str and (
'.png' in self.question_str or 'jpg' in self.question_str
): # direct image link
# include https:// in the image path
img_path = 'https://' + self.question_str.split('https://')[1]
self.vars['img_path'] = img_path
@ -104,16 +120,28 @@ class PRQuestions:
variables = copy.deepcopy(self.vars)
variables["diff"] = self.patches_diff # update diff
environment = Environment(undefined=StrictUndefined)
system_prompt = environment.from_string(get_settings().pr_questions_prompt.system).render(variables)
user_prompt = environment.from_string(get_settings().pr_questions_prompt.user).render(variables)
system_prompt = environment.from_string(
get_settings().pr_questions_prompt.system
).render(variables)
user_prompt = environment.from_string(
get_settings().pr_questions_prompt.user
).render(variables)
if 'img_path' in variables:
img_path = self.vars['img_path']
response, finish_reason = await (self.ai_handler.chat_completion
(model=model, temperature=get_settings().config.temperature,
system=system_prompt, user=user_prompt, img_path=img_path))
response, finish_reason = await self.ai_handler.chat_completion(
model=model,
temperature=get_settings().config.temperature,
system=system_prompt,
user=user_prompt,
img_path=img_path,
)
else:
response, finish_reason = await self.ai_handler.chat_completion(
model=model, temperature=get_settings().config.temperature, system=system_prompt, user=user_prompt)
model=model,
temperature=get_settings().config.temperature,
system=system_prompt,
user=user_prompt,
)
return response
def _prepare_pr_answer(self) -> str:
@ -123,9 +151,13 @@ class PRQuestions:
if model_answer_sanitized.startswith("/"):
model_answer_sanitized = " " + model_answer_sanitized
if model_answer_sanitized != model_answer:
get_logger().debug(f"Sanitized model answer",
artifact={"model_answer": model_answer, "sanitized_answer": model_answer_sanitized})
get_logger().debug(
f"Sanitized model answer",
artifact={
"model_answer": model_answer,
"sanitized_answer": model_answer_sanitized,
},
)
answer_str = f"### **Ask**❓\n{self.question_str}\n\n"
answer_str += f"### **Answer:**\n{model_answer_sanitized}\n\n"
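The question parser above recognizes two ways an /ask question can reference a screenshot: an inline markdown image tag or a bare https link ending in .png/.jpg. A compact sketch of that detection; the helper name is made up and edge cases are simplified:

def extract_image_path(question_str: str):
    # /ask question ... > ![image](https://host/screenshot.png)
    if '![image]' in question_str:
        return question_str.split('![image]')[1].strip().strip('()')
    # A direct image link pasted into the question text.
    if 'https://' in question_str and ('.png' in question_str or 'jpg' in question_str):
        return 'https://' + question_str.split('https://')[1]
    return None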


@ -7,21 +7,29 @@ from jinja2 import Environment, StrictUndefined
from utils.pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from utils.pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler
from utils.pr_agent.algo.pr_processing import (add_ai_metadata_to_diff_files,
get_pr_diff,
retry_with_fallback_models)
from utils.pr_agent.algo.pr_processing import (
add_ai_metadata_to_diff_files,
get_pr_diff,
retry_with_fallback_models,
)
from utils.pr_agent.algo.token_handler import TokenHandler
from utils.pr_agent.algo.utils import (ModelType, PRReviewHeader,
convert_to_markdown_v2, github_action_output,
load_yaml, show_relevant_configurations)
from utils.pr_agent.algo.utils import (
ModelType,
PRReviewHeader,
convert_to_markdown_v2,
github_action_output,
load_yaml,
show_relevant_configurations,
)
from utils.pr_agent.config_loader import get_settings
from utils.pr_agent.git_providers import (get_git_provider_with_context)
from utils.pr_agent.git_providers.git_provider import (IncrementalPR,
get_main_pr_language)
from utils.pr_agent.git_providers import get_git_provider_with_context
from utils.pr_agent.git_providers.git_provider import (
IncrementalPR,
get_main_pr_language,
)
from utils.pr_agent.log import get_logger
from utils.pr_agent.servers.help import HelpMessage
from utils.pr_agent.tools.ticket_pr_compliance_check import (
extract_and_cache_pr_tickets)
from utils.pr_agent.tools.ticket_pr_compliance_check import extract_and_cache_pr_tickets
class PRReviewer:
@ -29,8 +37,14 @@ class PRReviewer:
The PRReviewer class is responsible for reviewing a pull request and generating feedback using an AI model.
"""
def __init__(self, pr_url: str, is_answer: bool = False, is_auto: bool = False, args: list = None,
ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
def __init__(
self,
pr_url: str,
is_answer: bool = False,
is_auto: bool = False,
args: list = None,
ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler,
):
"""
Initialize the PRReviewer object with the necessary attributes and objects to review a pull request.
@ -55,16 +69,23 @@ class PRReviewer:
self.is_auto = is_auto
if self.is_answer and not self.git_provider.is_supported("get_issue_comments"):
raise Exception(f"Answer mode is not supported for {get_settings().config.git_provider} for now")
raise Exception(
f"Answer mode is not supported for {get_settings().config.git_provider} for now"
)
self.ai_handler = ai_handler()
self.ai_handler.main_pr_language = self.main_language
self.patches_diff = None
self.prediction = None
answer_str, question_str = self._get_user_answers()
self.pr_description, self.pr_description_files = (
self.git_provider.get_pr_description(split_changes_walkthrough=True))
if (self.pr_description_files and get_settings().get("config.is_auto_command", False) and
get_settings().get("config.enable_ai_metadata", False)):
(
self.pr_description,
self.pr_description_files,
) = self.git_provider.get_pr_description(split_changes_walkthrough=True)
if (
self.pr_description_files
and get_settings().get("config.is_auto_command", False)
and get_settings().get("config.enable_ai_metadata", False)
):
add_ai_metadata_to_diff_files(self.git_provider, self.pr_description_files)
get_logger().debug(f"AI metadata added to the this command")
else:
@ -89,9 +110,11 @@ class PRReviewer:
"commit_messages_str": self.git_provider.get_commit_messages(),
"custom_labels": "",
"enable_custom_labels": get_settings().config.enable_custom_labels,
"is_ai_metadata": get_settings().get("config.enable_ai_metadata", False),
"is_ai_metadata": get_settings().get("config.enable_ai_metadata", False),
"related_tickets": get_settings().get('related_tickets', []),
'duplicate_prompt_examples': get_settings().config.get('duplicate_prompt_examples', False),
'duplicate_prompt_examples': get_settings().config.get(
'duplicate_prompt_examples', False
),
"date": datetime.datetime.now().strftime('%Y-%m-%d'),
}
@ -99,7 +122,7 @@ class PRReviewer:
self.git_provider.pr,
self.vars,
get_settings().pr_review_prompt.system,
get_settings().pr_review_prompt.user
get_settings().pr_review_prompt.user,
)
def parse_incremental(self, args: List[str]):
@ -117,7 +140,10 @@ class PRReviewer:
get_logger().info(f"PR has no files: {self.pr_url}, skipping review")
return None
if self.incremental.is_incremental and not self._can_run_incremental_review():
if (
self.incremental.is_incremental
and not self._can_run_incremental_review()
):
return None
# if isinstance(self.args, list) and self.args and self.args[0] == 'auto_approve':
@ -126,27 +152,41 @@ class PRReviewer:
# return None
get_logger().info(f'Reviewing PR: {self.pr_url} ...')
relevant_configs = {'pr_reviewer': dict(get_settings().pr_reviewer),
'config': dict(get_settings().config)}
relevant_configs = {
'pr_reviewer': dict(get_settings().pr_reviewer),
'config': dict(get_settings().config),
}
get_logger().debug("Relevant configs", artifacts=relevant_configs)
# ticket extraction if exists
await extract_and_cache_pr_tickets(self.git_provider, self.vars)
if self.incremental.is_incremental and hasattr(self.git_provider, "unreviewed_files_set") and not self.git_provider.unreviewed_files_set:
get_logger().info(f"Incremental review is enabled for {self.pr_url} but there are no new files")
if (
self.incremental.is_incremental
and hasattr(self.git_provider, "unreviewed_files_set")
and not self.git_provider.unreviewed_files_set
):
get_logger().info(
f"Incremental review is enabled for {self.pr_url} but there are no new files"
)
previous_review_url = ""
if hasattr(self.git_provider, "previous_review"):
previous_review_url = self.git_provider.previous_review.html_url
if get_settings().config.publish_output:
self.git_provider.publish_comment(f"Incremental Review Skipped\n"
f"No files were changed since the [previous PR Review]({previous_review_url})")
self.git_provider.publish_comment(
f"Incremental Review Skipped\n"
f"No files were changed since the [previous PR Review]({previous_review_url})"
)
return None
if get_settings().config.publish_output and not get_settings().config.get('is_auto_command', False):
if get_settings().config.publish_output and not get_settings().config.get(
'is_auto_command', False
):
self.git_provider.publish_comment("准备评审中...", is_temporary=True)
await retry_with_fallback_models(self._prepare_prediction, model_type=ModelType.REGULAR)
await retry_with_fallback_models(
self._prepare_prediction, model_type=ModelType.REGULAR
)
if not self.prediction:
self.git_provider.remove_initial_comment()
return None
@ -156,12 +196,19 @@ class PRReviewer:
if get_settings().config.publish_output:
# publish the review
if get_settings().pr_reviewer.persistent_comment and not self.incremental.is_incremental:
final_update_message = get_settings().pr_reviewer.final_update_message
self.git_provider.publish_persistent_comment(pr_review,
initial_header=f"{PRReviewHeader.REGULAR.value} 🔍",
update_header=True,
final_update_message=final_update_message, )
if (
get_settings().pr_reviewer.persistent_comment
and not self.incremental.is_incremental
):
final_update_message = (
get_settings().pr_reviewer.final_update_message
)
self.git_provider.publish_persistent_comment(
pr_review,
initial_header=f"{PRReviewHeader.REGULAR.value} 🔍",
update_header=True,
final_update_message=final_update_message,
)
else:
self.git_provider.publish_comment(pr_review)
@ -174,11 +221,13 @@ class PRReviewer:
get_logger().error(f"Failed to review PR: {e}")
async def _prepare_prediction(self, model: str) -> None:
self.patches_diff = get_pr_diff(self.git_provider,
self.token_handler,
model,
add_line_numbers_to_hunks=True,
disable_extra_lines=False,)
self.patches_diff = get_pr_diff(
self.git_provider,
self.token_handler,
model,
add_line_numbers_to_hunks=True,
disable_extra_lines=False,
)
if self.patches_diff:
get_logger().debug(f"PR diff", diff=self.patches_diff)
@ -201,14 +250,18 @@ class PRReviewer:
variables["diff"] = self.patches_diff # update diff
environment = Environment(undefined=StrictUndefined)
system_prompt = environment.from_string(get_settings().pr_review_prompt.system).render(variables)
user_prompt = environment.from_string(get_settings().pr_review_prompt.user).render(variables)
system_prompt = environment.from_string(
get_settings().pr_review_prompt.system
).render(variables)
user_prompt = environment.from_string(
get_settings().pr_review_prompt.user
).render(variables)
response, finish_reason = await self.ai_handler.chat_completion(
model=model,
temperature=get_settings().config.temperature,
system=system_prompt,
user=user_prompt
user=user_prompt,
)
return response
@ -220,10 +273,20 @@ class PRReviewer:
"""
first_key = 'review'
last_key = 'security_concerns'
data = load_yaml(self.prediction.strip(),
keys_fix_yaml=["ticket_compliance_check", "estimated_effort_to_review_[1-5]:", "security_concerns:", "key_issues_to_review:",
"relevant_file:", "relevant_line:", "suggestion:"],
first_key=first_key, last_key=last_key)
data = load_yaml(
self.prediction.strip(),
keys_fix_yaml=[
"ticket_compliance_check",
"estimated_effort_to_review_[1-5]:",
"security_concerns:",
"key_issues_to_review:",
"relevant_file:",
"relevant_line:",
"suggestion:",
],
first_key=first_key,
last_key=last_key,
)
github_action_output(data, 'review')
# move data['review'] 'key_issues_to_review' key to the end of the dictionary
@ -234,24 +297,38 @@ class PRReviewer:
incremental_review_markdown_text = None
# Add incremental review section
if self.incremental.is_incremental:
last_commit_url = f"{self.git_provider.get_pr_url()}/commits/" \
f"{self.git_provider.incremental.first_new_commit_sha}"
last_commit_url = (
f"{self.git_provider.get_pr_url()}/commits/"
f"{self.git_provider.incremental.first_new_commit_sha}"
)
incremental_review_markdown_text = f"Starting from commit {last_commit_url}"
markdown_text = convert_to_markdown_v2(data, self.git_provider.is_supported("gfm_markdown"),
incremental_review_markdown_text,
git_provider=self.git_provider,
files=self.git_provider.get_diff_files())
markdown_text = convert_to_markdown_v2(
data,
self.git_provider.is_supported("gfm_markdown"),
incremental_review_markdown_text,
git_provider=self.git_provider,
files=self.git_provider.get_diff_files(),
)
# Add help text if gfm_markdown is supported
if self.git_provider.is_supported("gfm_markdown") and get_settings().pr_reviewer.enable_help_text:
if (
self.git_provider.is_supported("gfm_markdown")
and get_settings().pr_reviewer.enable_help_text
):
markdown_text += "<hr>\n\n<details> <summary><strong>💡 Tool usage guide:</strong></summary><hr> \n\n"
markdown_text += HelpMessage.get_review_usage_guide()
markdown_text += "\n</details>\n"
# Output the relevant configurations if enabled
if get_settings().get('config', {}).get('output_relevant_configurations', False):
markdown_text += show_relevant_configurations(relevant_section='pr_reviewer')
if (
get_settings()
.get('config', {})
.get('output_relevant_configurations', False)
):
markdown_text += show_relevant_configurations(
relevant_section='pr_reviewer'
)
# Add custom labels from the review prediction (effort, security)
self.set_review_labels(data)
@ -306,34 +383,50 @@ class PRReviewer:
if comment:
self.git_provider.remove_comment(comment)
except Exception as e:
get_logger().exception(f"Failed to remove previous review comment, error: {e}")
get_logger().exception(
f"Failed to remove previous review comment, error: {e}"
)
def _can_run_incremental_review(self) -> bool:
"""Checks if we can run incremental review according the various configurations and previous review"""
# checking if running is auto mode but there are no new commits
if self.is_auto and not self.incremental.first_new_commit_sha:
get_logger().info(f"Incremental review is enabled for {self.pr_url} but there are no new commits")
get_logger().info(
f"Incremental review is enabled for {self.pr_url} but there are no new commits"
)
return False
if not hasattr(self.git_provider, "get_incremental_commits"):
get_logger().info(f"Incremental review is not supported for {get_settings().config.git_provider}")
get_logger().info(
f"Incremental review is not supported for {get_settings().config.git_provider}"
)
return False
# checking if there are enough commits to start the review
num_new_commits = len(self.incremental.commits_range)
num_commits_threshold = get_settings().pr_reviewer.minimal_commits_for_incremental_review
num_commits_threshold = (
get_settings().pr_reviewer.minimal_commits_for_incremental_review
)
not_enough_commits = num_new_commits < num_commits_threshold
# checking if the commits are not too recent to start the review
recent_commits_threshold = datetime.datetime.now() - datetime.timedelta(
minutes=get_settings().pr_reviewer.minimal_minutes_for_incremental_review
)
last_seen_commit_date = (
self.incremental.last_seen_commit.commit.author.date if self.incremental.last_seen_commit else None
self.incremental.last_seen_commit.commit.author.date
if self.incremental.last_seen_commit
else None
)
all_commits_too_recent = (
last_seen_commit_date > recent_commits_threshold if self.incremental.last_seen_commit else False
last_seen_commit_date > recent_commits_threshold
if self.incremental.last_seen_commit
else False
)
# check all the thresholds or just one to start the review
condition = any if get_settings().pr_reviewer.require_all_thresholds_for_incremental_review else all
condition = (
any
if get_settings().pr_reviewer.require_all_thresholds_for_incremental_review
else all
)
if condition((not_enough_commits, all_commits_too_recent)):
get_logger().info(
f"Incremental review is enabled for {self.pr_url} but didn't pass the threshold check to run:"
@ -348,31 +441,55 @@ class PRReviewer:
return
if not get_settings().pr_reviewer.require_estimate_effort_to_review:
get_settings().pr_reviewer.enable_review_labels_effort = False # we did not generate this output
get_settings().pr_reviewer.enable_review_labels_effort = (
False # we did not generate this output
)
if not get_settings().pr_reviewer.require_security_review:
get_settings().pr_reviewer.enable_review_labels_security = False # we did not generate this output
get_settings().pr_reviewer.enable_review_labels_security = (
False # we did not generate this output
)
if (get_settings().pr_reviewer.enable_review_labels_security or
get_settings().pr_reviewer.enable_review_labels_effort):
if (
get_settings().pr_reviewer.enable_review_labels_security
or get_settings().pr_reviewer.enable_review_labels_effort
):
try:
review_labels = []
if get_settings().pr_reviewer.enable_review_labels_effort:
estimated_effort = data['review']['estimated_effort_to_review_[1-5]']
estimated_effort = data['review'][
'estimated_effort_to_review_[1-5]'
]
estimated_effort_number = 0
if isinstance(estimated_effort, str):
try:
estimated_effort_number = int(estimated_effort.split(',')[0])
estimated_effort_number = int(
estimated_effort.split(',')[0]
)
except ValueError:
get_logger().warning(f"Invalid estimated_effort value: {estimated_effort}")
get_logger().warning(
f"Invalid estimated_effort value: {estimated_effort}"
)
elif isinstance(estimated_effort, int):
estimated_effort_number = estimated_effort
else:
get_logger().warning(f"Unexpected type for estimated_effort: {type(estimated_effort)}")
get_logger().warning(
f"Unexpected type for estimated_effort: {type(estimated_effort)}"
)
if 1 <= estimated_effort_number <= 5: # 1, because ...
review_labels.append(f'Review effort {estimated_effort_number}/5')
if get_settings().pr_reviewer.enable_review_labels_security and get_settings().pr_reviewer.require_security_review:
security_concerns = data['review']['security_concerns'] # yes, because ...
security_concerns_bool = 'yes' in security_concerns.lower() or 'true' in security_concerns.lower()
review_labels.append(
f'Review effort {estimated_effort_number}/5'
)
if (
get_settings().pr_reviewer.enable_review_labels_security
and get_settings().pr_reviewer.require_security_review
):
security_concerns = data['review'][
'security_concerns'
] # yes, because ...
security_concerns_bool = (
'yes' in security_concerns.lower()
or 'true' in security_concerns.lower()
)
if security_concerns_bool:
review_labels.append('Possible security concern')
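As a side note on the hunk above: the `estimated_effort_to_review_[1-5]` value may come back from the model either as an int or as a string such as "3, ...", and the code tolerates both. A minimal standalone sketch of that parsing (warnings omitted; the sample strings are made up for illustration):

def parse_effort(estimated_effort) -> int:
    """Mirror of the parsing above: accept 3 or a string like "3, <reason>"."""
    if isinstance(estimated_effort, str):
        try:
            return int(estimated_effort.split(',')[0])
        except ValueError:
            return 0
    if isinstance(estimated_effort, int):
        return estimated_effort
    return 0

print(parse_effort("3, the diff touches two files"))  # 3
print(parse_effort(4))                                # 4
print(parse_effort("unclear"))                        # 0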
@ -381,17 +498,26 @@ class PRReviewer:
current_labels = []
get_logger().debug(f"Current labels:\n{current_labels}")
if current_labels:
current_labels_filtered = [label for label in current_labels if
not label.lower().startswith('review effort') and not label.lower().startswith(
'possible security concern')]
current_labels_filtered = [
label
for label in current_labels
if not label.lower().startswith('review effort')
and not label.lower().startswith('possible security concern')
]
else:
current_labels_filtered = []
new_labels = review_labels + current_labels_filtered
if (current_labels or review_labels) and sorted(new_labels) != sorted(current_labels):
get_logger().info(f"Setting review labels:\n{review_labels + current_labels_filtered}")
if (current_labels or review_labels) and sorted(new_labels) != sorted(
current_labels
):
get_logger().info(
f"Setting review labels:\n{review_labels + current_labels_filtered}"
)
self.git_provider.publish_labels(new_labels)
else:
get_logger().info(f"Review labels are already set:\n{review_labels + current_labels_filtered}")
get_logger().info(
f"Review labels are already set:\n{review_labels + current_labels_filtered}"
)
except Exception as e:
get_logger().error(f"Failed to set review labels, error: {e}")
@ -406,5 +532,7 @@ class PRReviewer:
self.git_provider.publish_comment("自动批准 PR")
else:
get_logger().info("Auto-approval option is disabled")
self.git_provider.publish_comment("PR-Agent 的自动批准选项已禁用. "
"你可以通过此设置打开 [configuration file](https://github.com/Codium-ai/pr-agent/blob/main/docs/REVIEW.md#auto-approval-1)")
self.git_provider.publish_comment(
"PR-Agent 的自动批准选项已禁用. "
"你可以通过此设置打开 [configuration file](https://github.com/Codium-ai/pr-agent/blob/main/docs/REVIEW.md#auto-approval-1)"
)

View File

@ -24,12 +24,16 @@ class PRSimilarIssue:
self.max_issues_to_scan = get_settings().pr_similar_issue.max_issues_to_scan
self.issue_url = issue_url
self.git_provider = get_git_provider()()
repo_name, issue_number = self.git_provider._parse_issue_url(issue_url.split('=')[-1])
repo_name, issue_number = self.git_provider._parse_issue_url(
issue_url.split('=')[-1]
)
self.git_provider.repo = repo_name
self.git_provider.repo_obj = self.git_provider.github_client.get_repo(repo_name)
self.token_handler = TokenHandler()
repo_obj = self.git_provider.repo_obj
repo_name_for_index = self.repo_name_for_index = repo_obj.full_name.lower().replace('/', '-').replace('_/', '-')
repo_name_for_index = self.repo_name_for_index = (
repo_obj.full_name.lower().replace('/', '-').replace('_/', '-')
)
index_name = self.index_name = "codium-ai-pr-agent-issues"
if get_settings().pr_similar_issue.vectordb == "pinecone":
@ -38,17 +42,30 @@ class PRSimilarIssue:
import pinecone
from pinecone_datasets import Dataset, DatasetMetadata
except:
raise Exception("Please install 'pinecone' and 'pinecone_datasets' to use pinecone as vectordb")
raise Exception(
"Please install 'pinecone' and 'pinecone_datasets' to use pinecone as vectordb"
)
# assuming pinecone api key and environment are set in secrets file
try:
api_key = get_settings().pinecone.api_key
environment = get_settings().pinecone.environment
except Exception:
if not self.cli_mode:
repo_name, original_issue_number = self.git_provider._parse_issue_url(self.issue_url.split('=')[-1])
issue_main = self.git_provider.repo_obj.get_issue(original_issue_number)
issue_main.create_comment("Please set pinecone api key and environment in secrets file")
raise Exception("Please set pinecone api key and environment in secrets file")
(
repo_name,
original_issue_number,
) = self.git_provider._parse_issue_url(
self.issue_url.split('=')[-1]
)
issue_main = self.git_provider.repo_obj.get_issue(
original_issue_number
)
issue_main.create_comment(
"Please set pinecone api key and environment in secrets file"
)
raise Exception(
"Please set pinecone api key and environment in secrets file"
)
# check if index exists, and if repo is already indexed
run_from_scratch = False
@ -69,7 +86,9 @@ class PRSimilarIssue:
upsert = True
else:
pinecone_index = pinecone.Index(index_name=index_name)
res = pinecone_index.fetch([f"example_issue_{repo_name_for_index}"]).to_dict()
res = pinecone_index.fetch(
[f"example_issue_{repo_name_for_index}"]
).to_dict()
if res["vectors"]:
upsert = False
@ -79,7 +98,9 @@ class PRSimilarIssue:
get_logger().info('Getting issues...')
issues = list(repo_obj.get_issues(state='all'))
get_logger().info('Done')
self._update_index_with_issues(issues, repo_name_for_index, upsert=upsert)
self._update_index_with_issues(
issues, repo_name_for_index, upsert=upsert
)
else: # update index if needed
pinecone_index = pinecone.Index(index_name=index_name)
issues_to_update = []
@ -105,7 +126,9 @@ class PRSimilarIssue:
if issues_to_update:
get_logger().info(f'Updating index with {counter} new issues...')
self._update_index_with_issues(issues_to_update, repo_name_for_index, upsert=True)
self._update_index_with_issues(
issues_to_update, repo_name_for_index, upsert=True
)
else:
get_logger().info('No new issues to update')
@ -133,7 +156,12 @@ class PRSimilarIssue:
ingest = True
else:
self.table = self.db[index_name]
res = self.table.search().limit(len(self.table)).where(f"id='example_issue_{repo_name_for_index}'").to_list()
res = (
self.table.search()
.limit(len(self.table))
.where(f"id='example_issue_{repo_name_for_index}'")
.to_list()
)
get_logger().info("result: ", res)
if res[0].get("vector"):
ingest = False
@ -145,7 +173,9 @@ class PRSimilarIssue:
issues = list(repo_obj.get_issues(state='all'))
get_logger().info('Done')
self._update_table_with_issues(issues, repo_name_for_index, ingest=ingest)
self._update_table_with_issues(
issues, repo_name_for_index, ingest=ingest
)
else: # update table if needed
issues_to_update = []
issues_paginated_list = repo_obj.get_issues(state='all')
@ -156,7 +186,12 @@ class PRSimilarIssue:
issue_str, comments, number = self._process_issue(issue)
issue_key = f"issue_{number}"
issue_id = issue_key + "." + "issue"
res = self.table.search().limit(len(self.table)).where(f"id='{issue_id}'").to_list()
res = (
self.table.search()
.limit(len(self.table))
.where(f"id='{issue_id}'")
.to_list()
)
is_new_issue = True
for r in res:
if r['metadata']['repo'] == repo_name_for_index:
@ -170,14 +205,17 @@ class PRSimilarIssue:
if issues_to_update:
get_logger().info(f'Updating index with {counter} new issues...')
self._update_table_with_issues(issues_to_update, repo_name_for_index, ingest=True)
self._update_table_with_issues(
issues_to_update, repo_name_for_index, ingest=True
)
else:
get_logger().info('No new issues to update')
async def run(self):
get_logger().info('Getting issue...')
repo_name, original_issue_number = self.git_provider._parse_issue_url(self.issue_url.split('=')[-1])
repo_name, original_issue_number = self.git_provider._parse_issue_url(
self.issue_url.split('=')[-1]
)
issue_main = self.git_provider.repo_obj.get_issue(original_issue_number)
issue_str, comments, number = self._process_issue(issue_main)
openai.api_key = get_settings().openai.key
@ -193,10 +231,12 @@ class PRSimilarIssue:
if get_settings().pr_similar_issue.vectordb == "pinecone":
pinecone_index = pinecone.Index(index_name=self.index_name)
res = pinecone_index.query(embeds[0],
top_k=5,
filter={"repo": self.repo_name_for_index},
include_metadata=True).to_dict()
res = pinecone_index.query(
embeds[0],
top_k=5,
filter={"repo": self.repo_name_for_index},
include_metadata=True,
).to_dict()
for r in res['matches']:
# skip example issue
@ -214,14 +254,20 @@ class PRSimilarIssue:
if issue_number not in relevant_issues_number_list:
relevant_issues_number_list.append(issue_number)
if 'comment' in r["id"]:
relevant_comment_number_list.append(int(r["id"].split('.')[1].split('_')[-1]))
relevant_comment_number_list.append(
int(r["id"].split('.')[1].split('_')[-1])
)
else:
relevant_comment_number_list.append(-1)
score_list.append(str("{:.2f}".format(r['score'])))
get_logger().info('Done')
elif get_settings().pr_similar_issue.vectordb == "lancedb":
res = self.table.search(embeds[0]).where(f"metadata.repo='{self.repo_name_for_index}'", prefilter=True).to_list()
res = (
self.table.search(embeds[0])
.where(f"metadata.repo='{self.repo_name_for_index}'", prefilter=True)
.to_list()
)
for r in res:
# skip example issue
@ -240,10 +286,12 @@ class PRSimilarIssue:
relevant_issues_number_list.append(issue_number)
if 'comment' in r["id"]:
relevant_comment_number_list.append(int(r["id"].split('.')[1].split('_')[-1]))
relevant_comment_number_list.append(
int(r["id"].split('.')[1].split('_')[-1])
)
else:
relevant_comment_number_list.append(-1)
score_list.append(str("{:.2f}".format(1-r['_distance'])))
score_list.append(str("{:.2f}".format(1 - r['_distance'])))
get_logger().info('Done')
get_logger().info('Publishing response...')
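For the LanceDB branch a few lines above, the score string is derived from the `_distance` column returned by `table.search()`. A minimal illustration of that conversion, assuming a metric where similarity is 1 - distance:

# Each LanceDB result row carries a `_distance` column; the tool reports
# similarity as 1 - distance, formatted to two decimals.
row = {"_distance": 0.13}
score = "{:.2f}".format(1 - row["_distance"])
print(score)  # 0.87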
@ -254,8 +302,12 @@ class PRSimilarIssue:
title = issue.title
url = issue.html_url
if relevant_comment_number_list[i] != -1:
url = list(issue.get_comments())[relevant_comment_number_list[i]].html_url
similar_issues_str += f"{i + 1}. **[{title}]({url})** (score={score_list[i]})\n\n"
url = list(issue.get_comments())[
relevant_comment_number_list[i]
].html_url
similar_issues_str += (
f"{i + 1}. **[{title}]({url})** (score={score_list[i]})\n\n"
)
if get_settings().config.publish_output:
response = issue_main.create_comment(similar_issues_str)
get_logger().info(similar_issues_str)
@ -278,7 +330,7 @@ class PRSimilarIssue:
example_issue_record = Record(
id=f"example_issue_{repo_name_for_index}",
text="example_issue",
metadata=Metadata(repo=repo_name_for_index)
metadata=Metadata(repo=repo_name_for_index),
)
corpus.append(example_issue_record)
@ -298,15 +350,20 @@ class PRSimilarIssue:
issue_key = f"issue_{number}"
username = issue.user.login
created_at = str(issue.created_at)
if len(issue_str) < 8000 or \
self.token_handler.count_tokens(issue_str) < get_max_tokens(MODEL): # fast reject first
if len(issue_str) < 8000 or self.token_handler.count_tokens(
issue_str
) < get_max_tokens(
MODEL
): # fast reject first
issue_record = Record(
id=issue_key + "." + "issue",
text=issue_str,
metadata=Metadata(repo=repo_name_for_index,
username=username,
created_at=created_at,
level=IssueLevel.ISSUE)
metadata=Metadata(
repo=repo_name_for_index,
username=username,
created_at=created_at,
level=IssueLevel.ISSUE,
),
)
corpus.append(issue_record)
if comments:
@ -316,15 +373,20 @@ class PRSimilarIssue:
if num_words_comment < 10 or not isinstance(comment_body, str):
continue
if len(comment_body) < 8000 or \
self.token_handler.count_tokens(comment_body) < MAX_TOKENS[MODEL]:
if (
len(comment_body) < 8000
or self.token_handler.count_tokens(comment_body)
< MAX_TOKENS[MODEL]
):
comment_record = Record(
id=issue_key + ".comment_" + str(j + 1),
text=comment_body,
metadata=Metadata(repo=repo_name_for_index,
username=username, # use issue username for all comments
created_at=created_at,
level=IssueLevel.COMMENT)
metadata=Metadata(
repo=repo_name_for_index,
username=username, # use issue username for all comments
created_at=created_at,
level=IssueLevel.COMMENT,
),
)
corpus.append(comment_record)
df = pd.DataFrame(corpus.dict()["documents"])
@ -355,7 +417,9 @@ class PRSimilarIssue:
environment = get_settings().pinecone.environment
if not upsert:
get_logger().info('Creating index from scratch...')
ds.to_pinecone_index(self.index_name, api_key=api_key, environment=environment)
ds.to_pinecone_index(
self.index_name, api_key=api_key, environment=environment
)
time.sleep(15) # wait for pinecone to finalize indexing before querying
else:
get_logger().info('Upserting index...')
@ -374,7 +438,7 @@ class PRSimilarIssue:
example_issue_record = Record(
id=f"example_issue_{repo_name_for_index}",
text="example_issue",
metadata=Metadata(repo=repo_name_for_index)
metadata=Metadata(repo=repo_name_for_index),
)
corpus.append(example_issue_record)
@ -394,15 +458,20 @@ class PRSimilarIssue:
issue_key = f"issue_{number}"
username = issue.user.login
created_at = str(issue.created_at)
if len(issue_str) < 8000 or \
self.token_handler.count_tokens(issue_str) < get_max_tokens(MODEL): # fast reject first
if len(issue_str) < 8000 or self.token_handler.count_tokens(
issue_str
) < get_max_tokens(
MODEL
): # fast reject first
issue_record = Record(
id=issue_key + "." + "issue",
text=issue_str,
metadata=Metadata(repo=repo_name_for_index,
username=username,
created_at=created_at,
level=IssueLevel.ISSUE)
metadata=Metadata(
repo=repo_name_for_index,
username=username,
created_at=created_at,
level=IssueLevel.ISSUE,
),
)
corpus.append(issue_record)
if comments:
@ -412,15 +481,20 @@ class PRSimilarIssue:
if num_words_comment < 10 or not isinstance(comment_body, str):
continue
if len(comment_body) < 8000 or \
self.token_handler.count_tokens(comment_body) < MAX_TOKENS[MODEL]:
if (
len(comment_body) < 8000
or self.token_handler.count_tokens(comment_body)
< MAX_TOKENS[MODEL]
):
comment_record = Record(
id=issue_key + ".comment_" + str(j + 1),
text=comment_body,
metadata=Metadata(repo=repo_name_for_index,
username=username, # use issue username for all comments
created_at=created_at,
level=IssueLevel.COMMENT)
metadata=Metadata(
repo=repo_name_for_index,
username=username, # use issue username for all comments
created_at=created_at,
level=IssueLevel.COMMENT,
),
)
corpus.append(comment_record)
df = pd.DataFrame(corpus.dict()["documents"])
@ -446,7 +520,9 @@ class PRSimilarIssue:
if not ingest:
get_logger().info('Creating table from scratch...')
self.table = self.db.create_table(self.index_name, data=df, mode="overwrite")
self.table = self.db.create_table(
self.index_name, data=df, mode="overwrite"
)
time.sleep(15)
else:
get_logger().info('Ingesting in Table...')

View File

@ -20,13 +20,20 @@ CHANGELOG_LINES = 50
class PRUpdateChangelog:
def __init__(self, pr_url: str, cli_mode=False, args=None, ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
def __init__(
self,
pr_url: str,
cli_mode=False,
args=None,
ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler,
):
self.git_provider = get_git_provider()(pr_url)
self.main_language = get_main_pr_language(
self.git_provider.get_languages(), self.git_provider.get_files()
)
self.commit_changelog = get_settings().pr_update_changelog.push_changelog_changes
self.commit_changelog = (
get_settings().pr_update_changelog.push_changelog_changes
)
self._get_changelog_file() # self.changelog_file_str
self.ai_handler = ai_handler()
@ -47,15 +54,19 @@ class PRUpdateChangelog:
"extra_instructions": get_settings().pr_update_changelog.extra_instructions,
"commit_messages_str": self.git_provider.get_commit_messages(),
}
self.token_handler = TokenHandler(self.git_provider.pr,
self.vars,
get_settings().pr_update_changelog_prompt.system,
get_settings().pr_update_changelog_prompt.user)
self.token_handler = TokenHandler(
self.git_provider.pr,
self.vars,
get_settings().pr_update_changelog_prompt.system,
get_settings().pr_update_changelog_prompt.user,
)
async def run(self):
get_logger().info('Updating the changelog...')
relevant_configs = {'pr_update_changelog': dict(get_settings().pr_update_changelog),
'config': dict(get_settings().config)}
relevant_configs = {
'pr_update_changelog': dict(get_settings().pr_update_changelog),
'config': dict(get_settings().config),
}
get_logger().debug("Relevant configs", artifacts=relevant_configs)
# currently only GitHub is supported for pushing changelog changes
@ -74,13 +85,21 @@ class PRUpdateChangelog:
if get_settings().config.publish_output:
self.git_provider.publish_comment("准备变更日志更新中...", is_temporary=True)
await retry_with_fallback_models(self._prepare_prediction, model_type=ModelType.WEAK)
await retry_with_fallback_models(
self._prepare_prediction, model_type=ModelType.WEAK
)
new_file_content, answer = self._prepare_changelog_update()
# Output the relevant configurations if enabled
if get_settings().get('config', {}).get('output_relevant_configurations', False):
answer += show_relevant_configurations(relevant_section='pr_update_changelog')
if (
get_settings()
.get('config', {})
.get('output_relevant_configurations', False)
):
answer += show_relevant_configurations(
relevant_section='pr_update_changelog'
)
get_logger().debug(f"PR output", artifact=answer)
@ -89,7 +108,9 @@ class PRUpdateChangelog:
if self.commit_changelog:
self._push_changelog_update(new_file_content, answer)
else:
self.git_provider.publish_comment(f"**Changelog updates:** 🔄\n\n{answer}")
self.git_provider.publish_comment(
f"**Changelog updates:** 🔄\n\n{answer}"
)
async def _prepare_prediction(self, model: str):
self.patches_diff = get_pr_diff(self.git_provider, self.token_handler, model)
@ -106,10 +127,18 @@ class PRUpdateChangelog:
if get_settings().pr_update_changelog.add_pr_link:
variables["pr_link"] = self.git_provider.get_pr_url()
environment = Environment(undefined=StrictUndefined)
system_prompt = environment.from_string(get_settings().pr_update_changelog_prompt.system).render(variables)
user_prompt = environment.from_string(get_settings().pr_update_changelog_prompt.user).render(variables)
system_prompt = environment.from_string(
get_settings().pr_update_changelog_prompt.system
).render(variables)
user_prompt = environment.from_string(
get_settings().pr_update_changelog_prompt.user
).render(variables)
response, finish_reason = await self.ai_handler.chat_completion(
model=model, system=system_prompt, user=user_prompt, temperature=get_settings().config.temperature)
model=model,
system=system_prompt,
user=user_prompt,
temperature=get_settings().config.temperature,
)
# post-process the response
response = response.strip()
@ -134,8 +163,10 @@ class PRUpdateChangelog:
new_file_content = answer
if not self.commit_changelog:
answer += "\n\n\n>to commit the new content to the CHANGELOG.md file, please type:" \
"\n>'/update_changelog --pr_update_changelog.push_changelog_changes=true'\n"
answer += (
"\n\n\n>to commit the new content to the CHANGELOG.md file, please type:"
"\n>'/update_changelog --pr_update_changelog.push_changelog_changes=true'\n"
)
return new_file_content, answer
@ -163,8 +194,7 @@ class PRUpdateChangelog:
self.git_provider.publish_comment(f"**Changelog updates: 🔄**\n\n{answer}")
def _get_default_changelog(self):
example_changelog = \
"""
example_changelog = """
Example:
## <current_date>

View File

@ -7,14 +7,15 @@ from utils.pr_agent.log import get_logger
# Compile the regex pattern once, outside the function
GITHUB_TICKET_PATTERN = re.compile(
r'(https://github[^/]+/[^/]+/[^/]+/issues/\d+)|(\b(\w+)/(\w+)#(\d+)\b)|(#\d+)'
r'(https://github[^/]+/[^/]+/[^/]+/issues/\d+)|(\b(\w+)/(\w+)#(\d+)\b)|(#\d+)'
)
def find_jira_tickets(text):
# Regular expression patterns for JIRA tickets
patterns = [
r'\b[A-Z]{2,10}-\d{1,7}\b', # Standard JIRA ticket format (e.g., PROJ-123)
r'(?:https?://[^\s/]+/browse/)?([A-Z]{2,10}-\d{1,7})\b' # JIRA URL or just the ticket
r'(?:https?://[^\s/]+/browse/)?([A-Z]{2,10}-\d{1,7})\b', # JIRA URL or just the ticket
]
tickets = set()
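A quick illustration of the GitHub ticket pattern shown in the hunk above; the sample PR description, repository names, and issue numbers below are made up:

import re

# Same pattern as above: full issue URL, owner/repo#123 shorthand, or bare #123.
GITHUB_TICKET_PATTERN = re.compile(
    r'(https://github[^/]+/[^/]+/[^/]+/issues/\d+)|(\b(\w+)/(\w+)#(\d+)\b)|(#\d+)'
)

sample = "Fixes https://github.com/acme/widgets/issues/42, see also acme/tools#7 and #99"
for groups in GITHUB_TICKET_PATTERN.findall(sample):
    # findall returns one tuple per match; only one alternative is populated.
    print([g for g in groups if g])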
@ -32,7 +33,9 @@ def find_jira_tickets(text):
return list(tickets)
def extract_ticket_links_from_pr_description(pr_description, repo_path, base_url_html='https://github.com'):
def extract_ticket_links_from_pr_description(
pr_description, repo_path, base_url_html='https://github.com'
):
"""
Extract all ticket links from PR description
"""
@ -46,19 +49,27 @@ def extract_ticket_links_from_pr_description(pr_description, repo_path, base_url
github_tickets.add(match[0])
elif match[1]: # Shorthand notation match: owner/repo#issue_number
owner, repo, issue_number = match[2], match[3], match[4]
github_tickets.add(f'{base_url_html.strip("/")}/{owner}/{repo}/issues/{issue_number}')
github_tickets.add(
f'{base_url_html.strip("/")}/{owner}/{repo}/issues/{issue_number}'
)
else: # #123 format
issue_number = match[5][1:] # remove #
if issue_number.isdigit() and len(issue_number) < 5 and repo_path:
github_tickets.add(f'{base_url_html.strip("/")}/{repo_path}/issues/{issue_number}')
github_tickets.add(
f'{base_url_html.strip("/")}/{repo_path}/issues/{issue_number}'
)
if len(github_tickets) > 3:
get_logger().info(f"Too many tickets found in PR description: {len(github_tickets)}")
get_logger().info(
f"Too many tickets found in PR description: {len(github_tickets)}"
)
# Limit the number of tickets to 3
github_tickets = set(list(github_tickets)[:3])
except Exception as e:
get_logger().error(f"Error extracting tickets error= {e}",
artifact={"traceback": traceback.format_exc()})
get_logger().error(
f"Error extracting tickets error= {e}",
artifact={"traceback": traceback.format_exc()},
)
return list(github_tickets)
@ -68,19 +79,26 @@ async def extract_tickets(git_provider):
try:
if isinstance(git_provider, GithubProvider):
user_description = git_provider.get_user_description()
tickets = extract_ticket_links_from_pr_description(user_description, git_provider.repo, git_provider.base_url_html)
tickets = extract_ticket_links_from_pr_description(
user_description, git_provider.repo, git_provider.base_url_html
)
tickets_content = []
if tickets:
for ticket in tickets:
repo_name, original_issue_number = git_provider._parse_issue_url(ticket)
repo_name, original_issue_number = git_provider._parse_issue_url(
ticket
)
try:
issue_main = git_provider.repo_obj.get_issue(original_issue_number)
issue_main = git_provider.repo_obj.get_issue(
original_issue_number
)
except Exception as e:
get_logger().error(f"Error getting main issue: {e}",
artifact={"traceback": traceback.format_exc()})
get_logger().error(
f"Error getting main issue: {e}",
artifact={"traceback": traceback.format_exc()},
)
continue
issue_body_str = issue_main.body or ""
@ -93,47 +111,66 @@ async def extract_tickets(git_provider):
sub_issues = git_provider.fetch_sub_issues(ticket)
for sub_issue_url in sub_issues:
try:
sub_repo, sub_issue_number = git_provider._parse_issue_url(sub_issue_url)
sub_issue = git_provider.repo_obj.get_issue(sub_issue_number)
(
sub_repo,
sub_issue_number,
) = git_provider._parse_issue_url(sub_issue_url)
sub_issue = git_provider.repo_obj.get_issue(
sub_issue_number
)
sub_body = sub_issue.body or ""
if len(sub_body) > MAX_TICKET_CHARACTERS:
sub_body = sub_body[:MAX_TICKET_CHARACTERS] + "..."
sub_issues_content.append({
'ticket_url': sub_issue_url,
'title': sub_issue.title,
'body': sub_body
})
sub_issues_content.append(
{
'ticket_url': sub_issue_url,
'title': sub_issue.title,
'body': sub_body,
}
)
except Exception as e:
get_logger().warning(f"Failed to fetch sub-issue content for {sub_issue_url}: {e}")
get_logger().warning(
f"Failed to fetch sub-issue content for {sub_issue_url}: {e}"
)
except Exception as e:
get_logger().warning(f"Failed to fetch sub-issues for {ticket}: {e}")
get_logger().warning(
f"Failed to fetch sub-issues for {ticket}: {e}"
)
# Extract labels
labels = []
try:
for label in issue_main.labels:
labels.append(label.name if hasattr(label, 'name') else label)
labels.append(
label.name if hasattr(label, 'name') else label
)
except Exception as e:
get_logger().error(f"Error extracting labels error= {e}",
artifact={"traceback": traceback.format_exc()})
get_logger().error(
f"Error extracting labels error= {e}",
artifact={"traceback": traceback.format_exc()},
)
tickets_content.append({
'ticket_id': issue_main.number,
'ticket_url': ticket,
'title': issue_main.title,
'body': issue_body_str,
'labels': ", ".join(labels),
'sub_issues': sub_issues_content # Store sub-issues content
})
tickets_content.append(
{
'ticket_id': issue_main.number,
'ticket_url': ticket,
'title': issue_main.title,
'body': issue_body_str,
'labels': ", ".join(labels),
'sub_issues': sub_issues_content, # Store sub-issues content
}
)
return tickets_content
except Exception as e:
get_logger().error(f"Error extracting tickets error= {e}",
artifact={"traceback": traceback.format_exc()})
get_logger().error(
f"Error extracting tickets error= {e}",
artifact={"traceback": traceback.format_exc()},
)
async def extract_and_cache_pr_tickets(git_provider, vars):
@ -154,8 +191,10 @@ async def extract_and_cache_pr_tickets(git_provider, vars):
related_tickets.append(ticket)
get_logger().info("Extracted tickets and sub-issues from PR description",
artifact={"tickets": related_tickets})
get_logger().info(
"Extracted tickets and sub-issues from PR description",
artifact={"tickets": related_tickets},
)
vars['related_tickets'] = related_tickets
get_settings().set('related_tickets', related_tickets)

config.ini (new file, 13 lines)

@ -0,0 +1,13 @@
[BASE]
; Enable debug mode: 0 or 1
DEBUG = 0
[DATABASE]
; Defaults to sqlite; replace with pg in production
DEFAULT = pg
; PostgreSQL settings
DB_NAME = pr_manager
DB_USER = admin
DB_PASSWORD = admin123456
DB_HOST = 110.40.30.95
DB_PORT = 5432

View File

@ -12,11 +12,18 @@ https://docs.djangoproject.com/en/5.1/ref/settings/
import os
import sys
import configparser
from pathlib import Path
# Build paths inside the project like this: BASE_DIR / 'subdir'.
BASE_DIR = Path(__file__).resolve().parent.parent
CONFIG_NAME = BASE_DIR / "config.ini"
# Load the configuration file; development can use config.local.ini
_config = configparser.ConfigParser()
_config.read(CONFIG_NAME, encoding="utf-8")
sys.path.insert(0, os.path.join(BASE_DIR, "apps"))
sys.path.insert(1, os.path.join(BASE_DIR, "apps/utils"))
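The hunk above reads only config.ini, while its comment mentions a config.local.ini for development. A minimal sketch of how such a layered read could look, assuming the override file is optional and git-ignored; this is an illustration, not part of the commit:

import configparser
from pathlib import Path

BASE_DIR = Path(__file__).resolve().parent.parent

# Hypothetical layered read: configparser.read() silently skips files that do
# not exist, so a git-ignored config.local.ini can override config.ini values
# on developer machines without touching the committed file.
_config = configparser.ConfigParser()
_config.read(
    [BASE_DIR / "config.ini", BASE_DIR / "config.local.ini"],
    encoding="utf-8",
)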
@ -27,7 +34,7 @@ sys.path.insert(1, os.path.join(BASE_DIR, "apps/utils"))
SECRET_KEY = "django-insecure-$r6lfcq8rev&&=chw259o$0o7t-!!%clc2ahs3xg$^z+gkms76"
# SECURITY WARNING: don't run with debug turned on in production!
DEBUG = False
DEBUG = bool(int(_config["BASE"].get("DEBUG", "1")))
ALLOWED_HOSTS = ["*"]
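The new DEBUG line above goes string → int → bool, so the INI value must be numeric. A quick standalone check of that cast chain:

# The INI value arrives as a string, so it is cast to int and then to bool.
# "0" -> False, "1" -> True; a non-numeric value such as "true" would raise
# ValueError, hence the "0 or 1" note in config.ini.
print(bool(int("0")))  # False
print(bool(int("1")))  # True

Note also that the fallback default here is "1" (debug on) when the key is missing, while the committed config.ini ships DEBUG = 0.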
@ -44,7 +51,7 @@ INSTALLED_APPS = [
"django.contrib.messages",
"django.contrib.staticfiles",
"public",
"pr"
"pr",
]
# Configure the security secret key
@ -68,8 +75,7 @@ ROOT_URLCONF = "pr_manager.urls"
TEMPLATES = [
{
"BACKEND": "django.template.backends.django.DjangoTemplates",
"DIRS": [BASE_DIR / 'templates']
,
"DIRS": [BASE_DIR / 'templates'],
"APP_DIRS": True,
"OPTIONS": {
"context_processors": [
@ -89,12 +95,22 @@ WSGI_APPLICATION = "pr_manager.wsgi.application"
# https://docs.djangoproject.com/en/5.1/ref/settings/#databases
DATABASES = {
"default": {
"pg": {
"ENGINE": "django.db.backends.postgresql",
"NAME": _config["DATABASE"].get("DB_NAME", "chat_ai_v2"),
"USER": _config["DATABASE"].get("DB_USER", "admin"),
"PASSWORD": _config["DATABASE"].get("DB_PASSWORD", "admin123456"),
"HOST": _config["DATABASE"].get("DB_HOST", "124.222.222.101"),
"PORT": int(_config["DATABASE"].get("DB_PORT", "5432")),
},
"sqlite": {
"ENGINE": "django.db.backends.sqlite3",
"NAME": BASE_DIR / "db.sqlite3",
}
},
}
DATABASES["default"] = DATABASES[_config["DATABASE"].get("DEFAULT", "sqlite")]
# Password validation
# https://docs.djangoproject.com/en/5.1/ref/settings/#auth-password-validators
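The DATABASES hunk above keeps both backend definitions in the dict and then aliases the one named by the [DATABASE] DEFAULT key to "default". A minimal standalone sketch of that selection, with the backend dicts trimmed to ENGINE only and an inline string standing in for config.ini:

import configparser

# Stand-in for config.ini; the committed file sets DEFAULT = pg.
_config = configparser.ConfigParser()
_config.read_string("[DATABASE]\nDEFAULT = pg\n")

DATABASES = {
    "pg": {"ENGINE": "django.db.backends.postgresql"},
    "sqlite": {"ENGINE": "django.db.backends.sqlite3"},
}
# Same selection as in settings.py: alias the chosen backend to "default".
DATABASES["default"] = DATABASES[_config["DATABASE"].get("DEFAULT", "sqlite")]

print(DATABASES["default"]["ENGINE"])  # django.db.backends.postgresql

One design note: with this approach "pg" and "sqlite" remain visible to Django as additional database aliases alongside "default"; if that is not desired, the lookup table could be kept in a private dict and only "default" exposed.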