From de84796560558b4866d5bc34198cd47ebcc7ec60 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E5=BB=BA=E5=B9=B3?= Date: Thu, 27 Feb 2025 11:07:34 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BB=A3=E7=A0=81=E4=BC=98=E5=8C=96=EF=BC=8C?= =?UTF-8?q?=E5=A2=9E=E5=BC=BA=E6=B8=85=E6=99=B0=E5=BA=A6=E5=92=8C=E5=8F=AF?= =?UTF-8?q?=E7=BB=B4=E6=8A=A4=E6=80=A7=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 3 +- Pipfile | 2 +- Pipfile.lock | 106 ++- apps/pr/admin.py | 8 +- apps/pr/management/commands/init_data.py | 1 - apps/pr/models.py | 15 +- apps/pr/views.py | 20 +- apps/utils/constant.py | 12 +- apps/utils/git_config.py | 52 +- apps/utils/pr_agent/agent/pr_agent.py | 12 +- apps/utils/pr_agent/algo/__init__.py | 4 +- .../algo/ai_handlers/base_ai_handler.py | 9 +- .../algo/ai_handlers/langchain_ai_handler.py | 20 +- .../algo/ai_handlers/litellm_ai_handler.py | 121 ++- .../algo/ai_handlers/openai_ai_handler.py | 32 +- apps/utils/pr_agent/algo/cli_args.py | 7 +- apps/utils/pr_agent/algo/file_filter.py | 22 +- .../pr_agent/algo/git_patch_processing.py | 307 +++++-- apps/utils/pr_agent/algo/language_handler.py | 20 +- apps/utils/pr_agent/algo/pr_processing.py | 398 ++++++--- apps/utils/pr_agent/algo/token_handler.py | 15 +- apps/utils/pr_agent/algo/utils.py | 446 +++++++--- apps/utils/pr_agent/cli.py | 40 +- apps/utils/pr_agent/cli_pip.py | 4 +- apps/utils/pr_agent/config_loader.py | 43 +- apps/utils/pr_agent/git_providers/__init__.py | 9 +- .../git_providers/azuredevops_provider.py | 282 ++++--- .../git_providers/bitbucket_provider.py | 284 ++++--- .../bitbucket_server_provider.py | 244 ++++-- .../git_providers/codecommit_client.py | 79 +- .../git_providers/codecommit_provider.py | 117 ++- .../pr_agent/git_providers/gerrit_provider.py | 138 ++-- .../pr_agent/git_providers/git_provider.py | 140 +++- .../pr_agent/git_providers/github_provider.py | 550 +++++++++---- .../pr_agent/git_providers/gitlab_provider.py | 364 ++++++--- .../git_providers/local_git_provider.py | 125 ++- apps/utils/pr_agent/git_providers/utils.py | 44 +- .../pr_agent/identity_providers/__init__.py | 9 +- .../default_identity_provider.py | 6 +- apps/utils/pr_agent/log/__init__.py | 6 +- .../pr_agent/secret_providers/__init__.py | 10 +- .../google_cloud_storage_secret_provider.py | 17 +- .../secret_providers/secret_provider.py | 1 - .../servers/azuredevops_server_webhook.py | 66 +- apps/utils/pr_agent/servers/bitbucket_app.py | 94 ++- .../servers/bitbucket_server_webhook.py | 25 +- apps/utils/pr_agent/servers/gerrit_server.py | 6 +- .../pr_agent/servers/github_action_runner.py | 66 +- apps/utils/pr_agent/servers/github_app.py | 290 +++++-- apps/utils/pr_agent/servers/github_polling.py | 169 ++-- apps/utils/pr_agent/servers/gitlab_webhook.py | 167 ++-- apps/utils/pr_agent/servers/help.py | 39 +- apps/utils/pr_agent/servers/utils.py | 15 +- apps/utils/pr_agent/tools/pr_add_docs.py | 128 ++- .../pr_agent/tools/pr_code_suggestions.py | 763 +++++++++++++----- apps/utils/pr_agent/tools/pr_config.py | 42 +- apps/utils/pr_agent/tools/pr_description.py | 511 ++++++++---- .../pr_agent/tools/pr_generate_labels.py | 28 +- apps/utils/pr_agent/tools/pr_help_message.py | 212 +++-- .../utils/pr_agent/tools/pr_line_questions.py | 66 +- apps/utils/pr_agent/tools/pr_questions.py | 70 +- apps/utils/pr_agent/tools/pr_reviewer.py | 296 +++++-- apps/utils/pr_agent/tools/pr_similar_issue.py | 184 +++-- .../pr_agent/tools/pr_update_changelog.py | 70 +- .../tools/ticket_pr_compliance_check.py | 115 
++- config.ini | 13 + pr_manager/settings.py | 28 +- 67 files changed, 5417 insertions(+), 2190 deletions(-) create mode 100644 config.ini diff --git a/.gitignore b/.gitignore index 35ca4c7..7c9d120 100644 --- a/.gitignore +++ b/.gitignore @@ -13,4 +13,5 @@ docs/.cache/ .qodo db.sqlite3 #pr_agent/ -static/admin/ \ No newline at end of file +static/admin/ +config.local.ini diff --git a/Pipfile b/Pipfile index ee1125c..21a48b6 100644 --- a/Pipfile +++ b/Pipfile @@ -20,9 +20,9 @@ pygithub = "*" python-gitlab = "*" retry = "*" fastapi = "*" +psycopg2-binary = "*" [dev-packages] [requires] python_version = "3.12" - diff --git a/Pipfile.lock b/Pipfile.lock index 8cc7042..8632a9b 100644 --- a/Pipfile.lock +++ b/Pipfile.lock @@ -1,7 +1,7 @@ { "_meta": { "hash": { - "sha256": "420206f7faa4351eabc368a83deae9b7ed9e50b0975ac63a46d6367e9920848b" + "sha256": "497c1ff8497659883faf8dcca407665df1b3a37f67720f64b139f9dec8202892" }, "pipfile-spec": 6, "requires": { @@ -169,20 +169,19 @@ }, "boto3": { "hashes": [ - "sha256:01015b38017876d79efd7273f35d9a4adfba505237159621365bed21b9b65eca", - "sha256:03bd8c93b226f07d944fd6b022e11a307bff94ab6a21d51675d7e3ea81ee8424" + "sha256:e58136d52d79425ce26c3c1578bf94d4b2e91ead55fed9f6950406ee9713e6af" ], "index": "pip_conf_index_global", "markers": "python_version >= '3.8'", - "version": "==1.37.0" + "version": "==1.37.2" }, "botocore": { "hashes": [ - "sha256:b129d091a8360b4152ab65327186bf4e250de827c4a9b7ddf40a72b1acf1f3c1", - "sha256:d01661f38c0edac87424344cdf4169f3ab9bc1bf1b677c8b230d025eb66c54a3" + "sha256:3f460f3c32cd6d747d5897a9cbde011bf1715abc7bf0a6ea6fdb0b812df63287", + "sha256:5f59b966f3cd0c8055ef6f7c2600f7db5f8218071d992e5f95da3f9156d4370f" ], "markers": "python_version >= '3.8'", - "version": "==1.37.0" + "version": "==1.37.2" }, "certifi": { "hashes": [ @@ -460,12 +459,12 @@ }, "django-import-export": { "hashes": [ - "sha256:317842a64233025a277040129fb6792fc48fd39622c185b70bf8c18c393d708f", - "sha256:ecb4e6cdb4790d69bce261f9cca1007ca19cb431bb5a950ba907898245c8817b" + "sha256:5514d09636e84e823a42cd5e79292f70f20d6d2feed117a145f5b64a5b44f168", + "sha256:bd3fe0aa15a2bce9de4be1a2f882e2c4539fdbfdfa16f2052c98dd7aec0f085c" ], "index": "pip_conf_index_global", "markers": "python_version >= '3.9'", - "version": "==4.3.6" + "version": "==4.3.7" }, "django-simpleui": { "hashes": [ @@ -794,12 +793,12 @@ }, "litellm": { "hashes": [ - "sha256:02df5865f98ea9734a4d27ac7c33aad9a45c4015403d5c0797d3292ade3c5cb5", - "sha256:d241436ac0edf64ec57fb5686f8d84a25998a7e52213d9063adf87df8432701f" + "sha256:eaab989c090ccc094b41c3fdf27d1df7f6fb25e091ab0ce48e0f3079f1e51ff5", + "sha256:ff9137c008cdb421db32defb1fbd1ed546a95167de6d276c61b664582ed4ff60" ], "index": "pip_conf_index_global", "markers": "python_version not in '2.7, 3.0, 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7' and python_version >= '3.8'", - "version": "==1.61.16" + "version": "==1.61.17" }, "loguru": { "hashes": [ @@ -1196,6 +1195,81 @@ "markers": "python_version >= '3.6'", "version": "==7.0.0" }, + "psycopg2-binary": { + "hashes": [ + "sha256:04392983d0bb89a8717772a193cfaac58871321e3ec69514e1c4e0d4957b5aff", + "sha256:056470c3dc57904bbf63d6f534988bafc4e970ffd50f6271fc4ee7daad9498a5", + "sha256:0ea8e3d0ae83564f2fc554955d327fa081d065c8ca5cc6d2abb643e2c9c1200f", + "sha256:155e69561d54d02b3c3209545fb08938e27889ff5a10c19de8d23eb5a41be8a5", + "sha256:18c5ee682b9c6dd3696dad6e54cc7ff3a1a9020df6a5c0f861ef8bfd338c3ca0", + "sha256:19721ac03892001ee8fdd11507e6a2e01f4e37014def96379411ca99d78aeb2c", + 
"sha256:1a6784f0ce3fec4edc64e985865c17778514325074adf5ad8f80636cd029ef7c", + "sha256:2286791ececda3a723d1910441c793be44625d86d1a4e79942751197f4d30341", + "sha256:230eeae2d71594103cd5b93fd29d1ace6420d0b86f4778739cb1a5a32f607d1f", + "sha256:245159e7ab20a71d989da00f280ca57da7641fa2cdcf71749c193cea540a74f7", + "sha256:26540d4a9a4e2b096f1ff9cce51253d0504dca5a85872c7f7be23be5a53eb18d", + "sha256:270934a475a0e4b6925b5f804e3809dd5f90f8613621d062848dd82f9cd62007", + "sha256:27422aa5f11fbcd9b18da48373eb67081243662f9b46e6fd07c3eb46e4535142", + "sha256:2ad26b467a405c798aaa1458ba09d7e2b6e5f96b1ce0ac15d82fd9f95dc38a92", + "sha256:2b3d2491d4d78b6b14f76881905c7a8a8abcf974aad4a8a0b065273a0ed7a2cb", + "sha256:2ce3e21dc3437b1d960521eca599d57408a695a0d3c26797ea0f72e834c7ffe5", + "sha256:30e34c4e97964805f715206c7b789d54a78b70f3ff19fbe590104b71c45600e5", + "sha256:3216ccf953b3f267691c90c6fe742e45d890d8272326b4a8b20850a03d05b7b8", + "sha256:32581b3020c72d7a421009ee1c6bf4a131ef5f0a968fab2e2de0c9d2bb4577f1", + "sha256:35958ec9e46432d9076286dda67942ed6d968b9c3a6a2fd62b48939d1d78bf68", + "sha256:3abb691ff9e57d4a93355f60d4f4c1dd2d68326c968e7db17ea96df3c023ef73", + "sha256:3c18f74eb4386bf35e92ab2354a12c17e5eb4d9798e4c0ad3a00783eae7cd9f1", + "sha256:3c4745a90b78e51d9ba06e2088a2fe0c693ae19cc8cb051ccda44e8df8a6eb53", + "sha256:3c4ded1a24b20021ebe677b7b08ad10bf09aac197d6943bfe6fec70ac4e4690d", + "sha256:3e9c76f0ac6f92ecfc79516a8034a544926430f7b080ec5a0537bca389ee0906", + "sha256:48b338f08d93e7be4ab2b5f1dbe69dc5e9ef07170fe1f86514422076d9c010d0", + "sha256:4b3df0e6990aa98acda57d983942eff13d824135fe2250e6522edaa782a06de2", + "sha256:512d29bb12608891e349af6a0cccedce51677725a921c07dba6342beaf576f9a", + "sha256:5a507320c58903967ef7384355a4da7ff3f28132d679aeb23572753cbf2ec10b", + "sha256:5c370b1e4975df846b0277b4deba86419ca77dbc25047f535b0bb03d1a544d44", + "sha256:6b269105e59ac96aba877c1707c600ae55711d9dcd3fc4b5012e4af68e30c648", + "sha256:6d4fa1079cab9018f4d0bd2db307beaa612b0d13ba73b5c6304b9fe2fb441ff7", + "sha256:6dc08420625b5a20b53551c50deae6e231e6371194fa0651dbe0fb206452ae1f", + "sha256:73aa0e31fa4bb82578f3a6c74a73c273367727de397a7a0f07bd83cbea696baa", + "sha256:7559bce4b505762d737172556a4e6ea8a9998ecac1e39b5233465093e8cee697", + "sha256:79625966e176dc97ddabc142351e0409e28acf4660b88d1cf6adb876d20c490d", + "sha256:7a813c8bdbaaaab1f078014b9b0b13f5de757e2b5d9be6403639b298a04d218b", + "sha256:7b2c956c028ea5de47ff3a8d6b3cc3330ab45cf0b7c3da35a2d6ff8420896526", + "sha256:7f4152f8f76d2023aac16285576a9ecd2b11a9895373a1f10fd9db54b3ff06b4", + "sha256:7f5d859928e635fa3ce3477704acee0f667b3a3d3e4bb109f2b18d4005f38287", + "sha256:851485a42dbb0bdc1edcdabdb8557c09c9655dfa2ca0460ff210522e073e319e", + "sha256:8608c078134f0b3cbd9f89b34bd60a943b23fd33cc5f065e8d5f840061bd0673", + "sha256:880845dfe1f85d9d5f7c412efea7a08946a46894537e4e5d091732eb1d34d9a0", + "sha256:8aabf1c1a04584c168984ac678a668094d831f152859d06e055288fa515e4d30", + "sha256:8aecc5e80c63f7459a1a2ab2c64df952051df196294d9f739933a9f6687e86b3", + "sha256:8cd9b4f2cfab88ed4a9106192de509464b75a906462fb846b936eabe45c2063e", + "sha256:8de718c0e1c4b982a54b41779667242bc630b2197948405b7bd8ce16bcecac92", + "sha256:9440fa522a79356aaa482aa4ba500b65f28e5d0e63b801abf6aa152a29bd842a", + "sha256:b5f86c56eeb91dc3135b3fd8a95dc7ae14c538a2f3ad77a19645cf55bab1799c", + "sha256:b73d6d7f0ccdad7bc43e6d34273f70d587ef62f824d7261c4ae9b8b1b6af90e8", + "sha256:bb89f0a835bcfc1d42ccd5f41f04870c1b936d8507c6df12b7737febc40f0909", + "sha256:c3cc28a6fd5a4a26224007712e79b81dbaee2ffb90ff406256158ec4d7b52b47", + 
"sha256:ce5ab4bf46a211a8e924d307c1b1fcda82368586a19d0a24f8ae166f5c784864", + "sha256:d00924255d7fc916ef66e4bf22f354a940c67179ad3fd7067d7a0a9c84d2fbfc", + "sha256:d7cd730dfa7c36dbe8724426bf5612798734bff2d3c3857f36f2733f5bfc7c00", + "sha256:e217ce4d37667df0bc1c397fdcd8de5e81018ef305aed9415c3b093faaeb10fb", + "sha256:e3923c1d9870c49a2d44f795df0c889a22380d36ef92440ff618ec315757e539", + "sha256:e5720a5d25e3b99cd0dc5c8a440570469ff82659bb09431c1439b92caf184d3b", + "sha256:e8b58f0a96e7a1e341fc894f62c1177a7c83febebb5ff9123b579418fdc8a481", + "sha256:e984839e75e0b60cfe75e351db53d6db750b00de45644c5d1f7ee5d1f34a1ce5", + "sha256:eb09aa7f9cecb45027683bb55aebaaf45a0df8bf6de68801a6afdc7947bb09d4", + "sha256:ec8a77f521a17506a24a5f626cb2aee7850f9b69a0afe704586f63a464f3cd64", + "sha256:ecced182e935529727401b24d76634a357c71c9275b356efafd8a2a91ec07392", + "sha256:ee0e8c683a7ff25d23b55b11161c2663d4b099770f6085ff0a20d4505778d6b4", + "sha256:f0c2d907a1e102526dd2986df638343388b94c33860ff3bbe1384130828714b1", + "sha256:f758ed67cab30b9a8d2833609513ce4d3bd027641673d4ebc9c067e4d208eec1", + "sha256:f8157bed2f51db683f31306aa497311b560f2265998122abe1dce6428bd86567", + "sha256:ffe8ed017e4ed70f68b7b371d84b7d4a790368db9203dfc2d222febd3a9c8863" + ], + "index": "pip_conf_index_global", + "markers": "python_version >= '3.8'", + "version": "==2.9.10" + }, "py": { "hashes": [ "sha256:51c75c4126074b472f746a24399ad32f6053d1b34b68d2fa41e558e6f4a98719", @@ -1766,11 +1840,11 @@ }, "s3transfer": { "hashes": [ - "sha256:3b39185cb72f5acc77db1a58b6e25b977f28d20496b6e58d6813d75f464d632f", - "sha256:be6ecb39fadd986ef1701097771f87e4d2f821f27f6071c872143884d2950fbc" + "sha256:ca855bdeb885174b5ffa95b9913622459d4ad8e331fc98eb01e6d5eb6a30655d", + "sha256:edae4977e3a122445660c7c114bba949f9d191bae3b34a096f18a1c8c354527a" ], "markers": "python_version >= '3.8'", - "version": "==0.11.2" + "version": "==0.11.3" }, "simplepro": { "hashes": [ diff --git a/apps/pr/admin.py b/apps/pr/admin.py index b15569b..78a659a 100644 --- a/apps/pr/admin.py +++ b/apps/pr/admin.py @@ -34,7 +34,13 @@ class GitConfigAdmin(AjaxAdmin): class ProjectConfigAdmin(AjaxAdmin): """Admin配置""" - list_display = ["project_id", "project_name", "project_secret", "commands", "is_enable"] + list_display = [ + "project_id", + "project_name", + "project_secret", + "commands", + "is_enable", + ] readonly_fields = ["create_by", "delete_at", "detail"] top_html = '' diff --git a/apps/pr/management/commands/init_data.py b/apps/pr/management/commands/init_data.py index 3a35864..1afc37f 100644 --- a/apps/pr/management/commands/init_data.py +++ b/apps/pr/management/commands/init_data.py @@ -16,4 +16,3 @@ class Command(BaseCommand): print("初始化AI配置已创建") else: print("初始化AI配置已存在") - diff --git a/apps/pr/models.py b/apps/pr/models.py index 971f6e9..9e1cf5f 100644 --- a/apps/pr/models.py +++ b/apps/pr/models.py @@ -44,9 +44,7 @@ class GitConfig(BaseModel): null=True, blank=True, max_length=16, verbose_name="Git名称" ) git_type = fields.RadioField( - choices=constant.GIT_TYPE, - default=0, - verbose_name="Git类型" + choices=constant.GIT_TYPE, default=0, verbose_name="Git类型" ) git_url = fields.CharField( null=True, blank=True, max_length=128, verbose_name="Git地址" @@ -67,6 +65,7 @@ class ProjectConfig(BaseModel): """ 项目配置表 """ + git_config = fields.ForeignKey( GitConfig, null=True, @@ -89,10 +88,7 @@ class ProjectConfig(BaseModel): max_length=256, verbose_name="默认命令", ) - is_enable = fields.SwitchField( - default=True, - verbose_name="是否启用" - ) + is_enable = fields.SwitchField(default=True, 
verbose_name="是否启用") class Meta: verbose_name = "项目配置" @@ -106,6 +102,7 @@ class ProjectHistory(BaseModel): """ 项目历史表 """ + project = fields.ForeignKey( ProjectConfig, null=True, @@ -128,9 +125,7 @@ class ProjectHistory(BaseModel): mr_title = fields.CharField( null=True, blank=True, max_length=256, verbose_name="MR标题" ) - source_data = models.JSONField( - null=True, blank=True, verbose_name="源数据" - ) + source_data = models.JSONField(null=True, blank=True, verbose_name="源数据") class Meta: verbose_name = "项目历史" diff --git a/apps/pr/views.py b/apps/pr/views.py index ccc8991..2821ea0 100644 --- a/apps/pr/views.py +++ b/apps/pr/views.py @@ -12,12 +12,7 @@ from utils import constant def load_project_config( - git_url, - access_token, - project_secret, - openai_api_base, - openai_key, - llm_model + git_url, access_token, project_secret, openai_api_base, openai_key, llm_model ): """ 加载项目配置 @@ -36,12 +31,11 @@ def load_project_config( "secret": project_secret, "openai_api_base": openai_api_base, "openai_key": openai_key, - "llm_model": llm_model + "llm_model": llm_model, } class WebHookView(View): - @staticmethod def select_git_provider(git_type): """ @@ -82,7 +76,9 @@ class WebHookView(View): project_config = provider.get_project_config(project_id=project_id) # Token 校验 - provider.check_secret(request_headers=headers, project_secret=project_config.get("project_secret")) + provider.check_secret( + request_headers=headers, project_secret=project_config.get("project_secret") + ) provider.get_merge_request( request_data=json_data, @@ -91,11 +87,13 @@ class WebHookView(View): api_base=project_config.get("api_base"), api_key=project_config.get("api_key"), llm_model=project_config.get("llm_model"), - project_commands=project_config.get("commands") + project_commands=project_config.get("commands"), ) # 记录请求日志: 目前仅记录合并日志 if json_data.get('object_kind') == 'merge_request': - provider.save_pr_agent_log(request_data=json_data, project_id=project_config.get("project_id")) + provider.save_pr_agent_log( + request_data=json_data, project_id=project_config.get("project_id") + ) return JsonResponse(status=200, data={"status": "ignored"}) diff --git a/apps/utils/constant.py b/apps/utils/constant.py index 6c3a7e9..903d9a7 100644 --- a/apps/utils/constant.py +++ b/apps/utils/constant.py @@ -1,8 +1,4 @@ -GIT_TYPE = ( - (0, "gitlab"), - (1, "github"), - (2, "gitea") -) +GIT_TYPE = ((0, "gitlab"), (1, "github"), (2, "gitea")) DEFAULT_COMMANDS = ( ("/review", "/review"), @@ -10,11 +6,7 @@ DEFAULT_COMMANDS = ( ("/improve_code", "/improve_code"), ) -UA_TYPE = { - "GitLab": "gitlab", - "GitHub": "github", - "Go-http-client": "gitea" -} +UA_TYPE = {"GitLab": "gitlab", "GitHub": "github", "Go-http-client": "gitea"} def get_git_type_from_ua(ua_value): diff --git a/apps/utils/git_config.py b/apps/utils/git_config.py index 72c4900..6a03260 100644 --- a/apps/utils/git_config.py +++ b/apps/utils/git_config.py @@ -16,14 +16,14 @@ class GitProvider(ABC): @abstractmethod def get_merge_request( - self, - request_data, - git_url, - access_token, - api_base, - api_key, - llm_model, - project_commands + self, + request_data, + git_url, + access_token, + api_base, + api_key, + llm_model, + project_commands, ): pass @@ -33,7 +33,6 @@ class GitProvider(ABC): class GitLabProvider(GitProvider): - @staticmethod def check_secret(request_headers, project_secret): """ @@ -79,18 +78,18 @@ class GitLabProvider(GitProvider): "access_token": git_config.access_token, "project_secret": project_config.project_secret, "commands": 
project_config.commands.split(","), - "project_id": project_config.id + "project_id": project_config.id, } def get_merge_request( - self, - request_data, - git_url, - access_token, - api_base, - api_key, - llm_model, - project_commands, + self, + request_data, + git_url, + access_token, + api_base, + api_key, + llm_model, + project_commands, ): """ 实现GitLab Merge Request获取逻辑 @@ -124,7 +123,10 @@ class GitLabProvider(GitProvider): self.run_command(mr_url, project_commands) # 数据库留存 return JsonResponse(status=200, data={"status": "review started"}) - return JsonResponse(status=400, data={"error": "Merge request URL not found or action not open"}) + return JsonResponse( + status=400, + data={"error": "Merge request URL not found or action not open"}, + ) @staticmethod def save_pr_agent_log(request_data, project_id): @@ -134,13 +136,19 @@ class GitLabProvider(GitProvider): :param project_id: :return: """ - if request_data.get('object_attributes', {}).get("source_branch") and request_data.get('object_attributes', {}).get("target_branch"): + if request_data.get('object_attributes', {}).get( + "source_branch" + ) and request_data.get('object_attributes', {}).get("target_branch"): models.ProjectHistory.objects.create( project_id=project_id, project_url=request_data.get("project", {}).get("web_url"), mr_url=request_data.get('object_attributes', {}).get("url"), - source_branch=request_data.get('object_attributes', {}).get("source_branch"), - target_branch=request_data.get('object_attributes', {}).get("target_branch"), + source_branch=request_data.get('object_attributes', {}).get( + "source_branch" + ), + target_branch=request_data.get('object_attributes', {}).get( + "target_branch" + ), mr_title=request_data.get('object_attributes', {}).get("title"), source_data=request_data, ) diff --git a/apps/utils/pr_agent/agent/pr_agent.py b/apps/utils/pr_agent/agent/pr_agent.py index 025254b..6ce86d3 100644 --- a/apps/utils/pr_agent/agent/pr_agent.py +++ b/apps/utils/pr_agent/agent/pr_agent.py @@ -80,14 +80,20 @@ class PRAgent: if action == "answer": if notify: notify() - await PRReviewer(pr_url, is_answer=True, args=args, ai_handler=self.ai_handler).run() + await PRReviewer( + pr_url, is_answer=True, args=args, ai_handler=self.ai_handler + ).run() elif action == "auto_review": - await PRReviewer(pr_url, is_auto=True, args=args, ai_handler=self.ai_handler).run() + await PRReviewer( + pr_url, is_auto=True, args=args, ai_handler=self.ai_handler + ).run() elif action in command2class: if notify: notify() - await command2class[action](pr_url, ai_handler=self.ai_handler, args=args).run() + await command2class[action]( + pr_url, ai_handler=self.ai_handler, args=args + ).run() else: return False return True diff --git a/apps/utils/pr_agent/algo/__init__.py b/apps/utils/pr_agent/algo/__init__.py index 37ca48a..e4e7e4c 100644 --- a/apps/utils/pr_agent/algo/__init__.py +++ b/apps/utils/pr_agent/algo/__init__.py @@ -88,7 +88,7 @@ USER_MESSAGE_ONLY_MODELS = [ "deepseek/deepseek-reasoner", "o1-mini", "o1-mini-2024-09-12", - "o1-preview" + "o1-preview", ] NO_SUPPORT_TEMPERATURE_MODELS = [ @@ -99,5 +99,5 @@ NO_SUPPORT_TEMPERATURE_MODELS = [ "o1-2024-12-17", "o3-mini", "o3-mini-2025-01-31", - "o1-preview" + "o1-preview", ] diff --git a/apps/utils/pr_agent/algo/ai_handlers/base_ai_handler.py b/apps/utils/pr_agent/algo/ai_handlers/base_ai_handler.py index 956fcaf..064f522 100644 --- a/apps/utils/pr_agent/algo/ai_handlers/base_ai_handler.py +++ b/apps/utils/pr_agent/algo/ai_handlers/base_ai_handler.py @@ -16,7 +16,14 @@ class 
BaseAiHandler(ABC): pass @abstractmethod - async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2, img_path: str = None): + async def chat_completion( + self, + model: str, + system: str, + user: str, + temperature: float = 0.2, + img_path: str = None, + ): """ This method should be implemented to return a chat completion from the AI model. Args: diff --git a/apps/utils/pr_agent/algo/ai_handlers/langchain_ai_handler.py b/apps/utils/pr_agent/algo/ai_handlers/langchain_ai_handler.py index af2c9b7..9673825 100644 --- a/apps/utils/pr_agent/algo/ai_handlers/langchain_ai_handler.py +++ b/apps/utils/pr_agent/algo/ai_handlers/langchain_ai_handler.py @@ -34,9 +34,16 @@ class LangChainOpenAIHandler(BaseAiHandler): """ return get_settings().get("OPENAI.DEPLOYMENT_ID", None) - @retry(exceptions=(APIError, Timeout, AttributeError, RateLimitError), - tries=OPENAI_RETRIES, delay=2, backoff=2, jitter=(1, 3)) - async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2): + @retry( + exceptions=(APIError, Timeout, AttributeError, RateLimitError), + tries=OPENAI_RETRIES, + delay=2, + backoff=2, + jitter=(1, 3), + ) + async def chat_completion( + self, model: str, system: str, user: str, temperature: float = 0.2 + ): try: messages = [SystemMessage(content=system), HumanMessage(content=user)] @@ -45,7 +52,7 @@ class LangChainOpenAIHandler(BaseAiHandler): finish_reason = "completed" return resp.content, finish_reason - except (Exception) as e: + except Exception as e: get_logger().error("Unknown error during OpenAI inference: ", e) raise e @@ -66,7 +73,10 @@ class LangChainOpenAIHandler(BaseAiHandler): if openai_api_base is None or len(openai_api_base) == 0: return ChatOpenAI(openai_api_key=get_settings().openai.key) else: - return ChatOpenAI(openai_api_key=get_settings().openai.key, openai_api_base=openai_api_base) + return ChatOpenAI( + openai_api_key=get_settings().openai.key, + openai_api_base=openai_api_base, + ) except AttributeError as e: if getattr(e, "name"): raise ValueError(f"OpenAI {e.name} is required") from e diff --git a/apps/utils/pr_agent/algo/ai_handlers/litellm_ai_handler.py b/apps/utils/pr_agent/algo/ai_handlers/litellm_ai_handler.py index 5a10640..7aa0876 100644 --- a/apps/utils/pr_agent/algo/ai_handlers/litellm_ai_handler.py +++ b/apps/utils/pr_agent/algo/ai_handlers/litellm_ai_handler.py @@ -36,9 +36,14 @@ class LiteLLMAIHandler(BaseAiHandler): elif 'OPENAI_API_KEY' not in os.environ: litellm.api_key = "dummy_key" if get_settings().get("aws.AWS_ACCESS_KEY_ID"): - assert get_settings().aws.AWS_SECRET_ACCESS_KEY and get_settings().aws.AWS_REGION_NAME, "AWS credentials are incomplete" + assert ( + get_settings().aws.AWS_SECRET_ACCESS_KEY + and get_settings().aws.AWS_REGION_NAME + ), "AWS credentials are incomplete" os.environ["AWS_ACCESS_KEY_ID"] = get_settings().aws.AWS_ACCESS_KEY_ID - os.environ["AWS_SECRET_ACCESS_KEY"] = get_settings().aws.AWS_SECRET_ACCESS_KEY + os.environ[ + "AWS_SECRET_ACCESS_KEY" + ] = get_settings().aws.AWS_SECRET_ACCESS_KEY os.environ["AWS_REGION_NAME"] = get_settings().aws.AWS_REGION_NAME if get_settings().get("litellm.use_client"): litellm_token = get_settings().get("litellm.LITELLM_TOKEN") @@ -73,14 +78,19 @@ class LiteLLMAIHandler(BaseAiHandler): litellm.replicate_key = get_settings().replicate.key if get_settings().get("HUGGINGFACE.KEY", None): litellm.huggingface_key = get_settings().huggingface.key - if get_settings().get("HUGGINGFACE.API_BASE", None) and 'huggingface' in 
get_settings().config.model: + if ( + get_settings().get("HUGGINGFACE.API_BASE", None) + and 'huggingface' in get_settings().config.model + ): litellm.api_base = get_settings().huggingface.api_base self.api_base = get_settings().huggingface.api_base if get_settings().get("OLLAMA.API_BASE", None): litellm.api_base = get_settings().ollama.api_base self.api_base = get_settings().ollama.api_base if get_settings().get("HUGGINGFACE.REPETITION_PENALTY", None): - self.repetition_penalty = float(get_settings().huggingface.repetition_penalty) + self.repetition_penalty = float( + get_settings().huggingface.repetition_penalty + ) if get_settings().get("VERTEXAI.VERTEX_PROJECT", None): litellm.vertex_project = get_settings().vertexai.vertex_project litellm.vertex_location = get_settings().get( @@ -89,7 +99,9 @@ class LiteLLMAIHandler(BaseAiHandler): # Google AI Studio # SEE https://docs.litellm.ai/docs/providers/gemini if get_settings().get("GOOGLE_AI_STUDIO.GEMINI_API_KEY", None): - os.environ["GEMINI_API_KEY"] = get_settings().google_ai_studio.gemini_api_key + os.environ[ + "GEMINI_API_KEY" + ] = get_settings().google_ai_studio.gemini_api_key # Support deepseek models if get_settings().get("DEEPSEEK.KEY", None): @@ -140,27 +152,35 @@ class LiteLLMAIHandler(BaseAiHandler): git_provider = get_settings().config.git_provider metadata = dict() - callbacks = litellm.success_callback + litellm.failure_callback + litellm.service_callback + callbacks = ( + litellm.success_callback + + litellm.failure_callback + + litellm.service_callback + ) if "langfuse" in callbacks: - metadata.update({ - "trace_name": command, - "tags": [git_provider, command, f'version:{get_version()}'], - "trace_metadata": { - "command": command, - "pr_url": pr_url, - }, - }) - if "langsmith" in callbacks: - metadata.update({ - "run_name": command, - "tags": [git_provider, command, f'version:{get_version()}'], - "extra": { - "metadata": { + metadata.update( + { + "trace_name": command, + "tags": [git_provider, command, f'version:{get_version()}'], + "trace_metadata": { "command": command, "pr_url": pr_url, - } - }, - }) + }, + } + ) + if "langsmith" in callbacks: + metadata.update( + { + "run_name": command, + "tags": [git_provider, command, f'version:{get_version()}'], + "extra": { + "metadata": { + "command": command, + "pr_url": pr_url, + } + }, + } + ) # Adding the captured logs to the kwargs kwargs["metadata"] = metadata @@ -175,10 +195,19 @@ class LiteLLMAIHandler(BaseAiHandler): return get_settings().get("OPENAI.DEPLOYMENT_ID", None) @retry( - retry=retry_if_exception_type((openai.APIError, openai.APIConnectionError, openai.APITimeoutError)), # No retry on RateLimitError - stop=stop_after_attempt(OPENAI_RETRIES) + retry=retry_if_exception_type( + (openai.APIError, openai.APIConnectionError, openai.APITimeoutError) + ), # No retry on RateLimitError + stop=stop_after_attempt(OPENAI_RETRIES), ) - async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2, img_path: str = None): + async def chat_completion( + self, + model: str, + system: str, + user: str, + temperature: float = 0.2, + img_path: str = None, + ): try: resp, finish_reason = None, None deployment_id = self.deployment_id @@ -187,8 +216,12 @@ class LiteLLMAIHandler(BaseAiHandler): if 'claude' in model and not system: system = "No system prompt provided" get_logger().warning( - "Empty system prompt for claude model. 
Adding a newline character to prevent OpenAI API error.") - messages = [{"role": "system", "content": system}, {"role": "user", "content": user}] + "Empty system prompt for claude model. Adding a newline character to prevent OpenAI API error." + ) + messages = [ + {"role": "system", "content": system}, + {"role": "user", "content": user}, + ] if img_path: try: @@ -201,14 +234,21 @@ class LiteLLMAIHandler(BaseAiHandler): except Exception as e: get_logger().error(f"Error fetching image: {img_path}", e) return f"Error fetching image: {img_path}", "error" - messages[1]["content"] = [{"type": "text", "text": messages[1]["content"]}, - {"type": "image_url", "image_url": {"url": img_path}}] + messages[1]["content"] = [ + {"type": "text", "text": messages[1]["content"]}, + {"type": "image_url", "image_url": {"url": img_path}}, + ] # Currently, some models do not support a separate system and user prompts - if model in self.user_message_only_models or get_settings().config.custom_reasoning_model: + if ( + model in self.user_message_only_models + or get_settings().config.custom_reasoning_model + ): user = f"{system}\n\n\n{user}" system = "" - get_logger().info(f"Using model {model}, combining system and user prompts") + get_logger().info( + f"Using model {model}, combining system and user prompts" + ) messages = [{"role": "user", "content": user}] kwargs = { "model": model, @@ -227,7 +267,10 @@ class LiteLLMAIHandler(BaseAiHandler): } # Add temperature only if model supports it - if model not in self.no_support_temperature_models and not get_settings().config.custom_reasoning_model: + if ( + model not in self.no_support_temperature_models + and not get_settings().config.custom_reasoning_model + ): kwargs["temperature"] = temperature if get_settings().litellm.get("enable_callbacks", False): @@ -235,7 +278,9 @@ class LiteLLMAIHandler(BaseAiHandler): seed = get_settings().config.get("seed", -1) if temperature > 0 and seed >= 0: - raise ValueError(f"Seed ({seed}) is not supported with temperature ({temperature}) > 0") + raise ValueError( + f"Seed ({seed}) is not supported with temperature ({temperature}) > 0" + ) elif seed >= 0: get_logger().info(f"Using fixed seed of {seed}") kwargs["seed"] = seed @@ -253,10 +298,10 @@ class LiteLLMAIHandler(BaseAiHandler): except (openai.APIError, openai.APITimeoutError) as e: get_logger().warning(f"Error during LLM inference: {e}") raise - except (openai.RateLimitError) as e: + except openai.RateLimitError as e: get_logger().error(f"Rate limit error during LLM inference: {e}") raise - except (Exception) as e: + except Exception as e: get_logger().warning(f"Unknown error during LLM inference: {e}") raise openai.APIError from e if response is None or len(response["choices"]) == 0: @@ -267,7 +312,9 @@ class LiteLLMAIHandler(BaseAiHandler): get_logger().debug(f"\nAI response:\n{resp}") # log the full response for debugging - response_log = self.prepare_logs(response, system, user, resp, finish_reason) + response_log = self.prepare_logs( + response, system, user, resp, finish_reason + ) get_logger().debug("Full_response", artifact=response_log) # for CLI debugging diff --git a/apps/utils/pr_agent/algo/ai_handlers/openai_ai_handler.py b/apps/utils/pr_agent/algo/ai_handlers/openai_ai_handler.py index ac8950c..6906613 100644 --- a/apps/utils/pr_agent/algo/ai_handlers/openai_ai_handler.py +++ b/apps/utils/pr_agent/algo/ai_handlers/openai_ai_handler.py @@ -37,13 +37,23 @@ class OpenAIHandler(BaseAiHandler): """ return get_settings().get("OPENAI.DEPLOYMENT_ID", None) - 
@retry(exceptions=(APIError, Timeout, AttributeError, RateLimitError), - tries=OPENAI_RETRIES, delay=2, backoff=2, jitter=(1, 3)) - async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2): + @retry( + exceptions=(APIError, Timeout, AttributeError, RateLimitError), + tries=OPENAI_RETRIES, + delay=2, + backoff=2, + jitter=(1, 3), + ) + async def chat_completion( + self, model: str, system: str, user: str, temperature: float = 0.2 + ): try: get_logger().info("System: ", system) get_logger().info("User: ", user) - messages = [{"role": "system", "content": system}, {"role": "user", "content": user}] + messages = [ + {"role": "system", "content": system}, + {"role": "user", "content": user}, + ] client = AsyncOpenAI() chat_completion = await client.chat.completions.create( model=model, @@ -53,15 +63,21 @@ class OpenAIHandler(BaseAiHandler): resp = chat_completion.choices[0].message.content finish_reason = chat_completion.choices[0].finish_reason usage = chat_completion.usage - get_logger().info("AI response", response=resp, messages=messages, finish_reason=finish_reason, - model=model, usage=usage) + get_logger().info( + "AI response", + response=resp, + messages=messages, + finish_reason=finish_reason, + model=model, + usage=usage, + ) return resp, finish_reason except (APIError, Timeout) as e: get_logger().error("Error during OpenAI inference: ", e) raise - except (RateLimitError) as e: + except RateLimitError as e: get_logger().error("Rate limit error during OpenAI inference: ", e) raise - except (Exception) as e: + except Exception as e: get_logger().error("Unknown error during OpenAI inference: ", e) raise diff --git a/apps/utils/pr_agent/algo/cli_args.py b/apps/utils/pr_agent/algo/cli_args.py index 4432469..ba6eeca 100644 --- a/apps/utils/pr_agent/algo/cli_args.py +++ b/apps/utils/pr_agent/algo/cli_args.py @@ -1,6 +1,7 @@ from base64 import b64decode import hashlib + class CliArgs: @staticmethod def validate_user_args(args: list) -> (bool, str): @@ -23,12 +24,12 @@ class CliArgs: for arg in args: if arg.startswith('--'): arg_word = arg.lower() - arg_word = arg_word.replace('__', '.') # replace double underscore with dot, e.g. --openai__key -> --openai.key + arg_word = arg_word.replace( + '__', '.' + ) # replace double underscore with dot, e.g. --openai__key -> --openai.key for forbidden_arg_word in forbidden_cli_args: if forbidden_arg_word in arg_word: return False, forbidden_arg_word return True, "" except Exception as e: return False, str(e) - - diff --git a/apps/utils/pr_agent/algo/file_filter.py b/apps/utils/pr_agent/algo/file_filter.py index b66febd..9dd6356 100644 --- a/apps/utils/pr_agent/algo/file_filter.py +++ b/apps/utils/pr_agent/algo/file_filter.py @@ -4,7 +4,7 @@ import re from utils.pr_agent.config_loader import get_settings -def filter_ignored(files, platform = 'github'): +def filter_ignored(files, platform='github'): """ Filter out files that match the ignore patterns. 
""" @@ -15,7 +15,9 @@ def filter_ignored(files, platform = 'github'): if isinstance(patterns, str): patterns = [patterns] glob_setting = get_settings().ignore.glob - if isinstance(glob_setting, str): # --ignore.glob=[.*utils.py], --ignore.glob=.*utils.py + if isinstance( + glob_setting, str + ): # --ignore.glob=[.*utils.py], --ignore.glob=.*utils.py glob_setting = glob_setting.strip('[]').split(",") patterns += [fnmatch.translate(glob) for glob in glob_setting] @@ -31,7 +33,9 @@ def filter_ignored(files, platform = 'github'): if files and isinstance(files, list): for r in compiled_patterns: if platform == 'github': - files = [f for f in files if (f.filename and not r.match(f.filename))] + files = [ + f for f in files if (f.filename and not r.match(f.filename)) + ] elif platform == 'bitbucket': # files = [f for f in files if (f.new.path and not r.match(f.new.path))] files_o = [] @@ -49,10 +53,18 @@ def filter_ignored(files, platform = 'github'): # files = [f for f in files if (f['new_path'] and not r.match(f['new_path']))] files_o = [] for f in files: - if 'new_path' in f and f['new_path'] and not r.match(f['new_path']): + if ( + 'new_path' in f + and f['new_path'] + and not r.match(f['new_path']) + ): files_o.append(f) continue - if 'old_path' in f and f['old_path'] and not r.match(f['old_path']): + if ( + 'old_path' in f + and f['old_path'] + and not r.match(f['old_path']) + ): files_o.append(f) continue files = files_o diff --git a/apps/utils/pr_agent/algo/git_patch_processing.py b/apps/utils/pr_agent/algo/git_patch_processing.py index c06228d..2a328ff 100644 --- a/apps/utils/pr_agent/algo/git_patch_processing.py +++ b/apps/utils/pr_agent/algo/git_patch_processing.py @@ -8,9 +8,18 @@ from utils.pr_agent.config_loader import get_settings from utils.pr_agent.log import get_logger -def extend_patch(original_file_str, patch_str, patch_extra_lines_before=0, - patch_extra_lines_after=0, filename: str = "") -> str: - if not patch_str or (patch_extra_lines_before == 0 and patch_extra_lines_after == 0) or not original_file_str: +def extend_patch( + original_file_str, + patch_str, + patch_extra_lines_before=0, + patch_extra_lines_after=0, + filename: str = "", +) -> str: + if ( + not patch_str + or (patch_extra_lines_before == 0 and patch_extra_lines_after == 0) + or not original_file_str + ): return patch_str original_file_str = decode_if_bytes(original_file_str) @@ -21,10 +30,17 @@ def extend_patch(original_file_str, patch_str, patch_extra_lines_before=0, return patch_str try: - extended_patch_str = process_patch_lines(patch_str, original_file_str, - patch_extra_lines_before, patch_extra_lines_after) + extended_patch_str = process_patch_lines( + patch_str, + original_file_str, + patch_extra_lines_before, + patch_extra_lines_after, + ) except Exception as e: - get_logger().warning(f"Failed to extend patch: {e}", artifact={"traceback": traceback.format_exc()}) + get_logger().warning( + f"Failed to extend patch: {e}", + artifact={"traceback": traceback.format_exc()}, + ) return patch_str return extended_patch_str @@ -48,13 +64,19 @@ def decode_if_bytes(original_file_str): def should_skip_patch(filename): patch_extension_skip_types = get_settings().config.patch_extension_skip_types if patch_extension_skip_types and filename: - return any(filename.endswith(skip_type) for skip_type in patch_extension_skip_types) + return any( + filename.endswith(skip_type) for skip_type in patch_extension_skip_types + ) return False -def process_patch_lines(patch_str, original_file_str, patch_extra_lines_before, 
patch_extra_lines_after): +def process_patch_lines( + patch_str, original_file_str, patch_extra_lines_before, patch_extra_lines_after +): allow_dynamic_context = get_settings().config.allow_dynamic_context - patch_extra_lines_before_dynamic = get_settings().config.max_extra_lines_before_dynamic_context + patch_extra_lines_before_dynamic = ( + get_settings().config.max_extra_lines_before_dynamic_context + ) original_lines = original_file_str.splitlines() len_original_lines = len(original_lines) @@ -63,59 +85,122 @@ def process_patch_lines(patch_str, original_file_str, patch_extra_lines_before, is_valid_hunk = True start1, size1, start2, size2 = -1, -1, -1, -1 - RE_HUNK_HEADER = re.compile( - r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)") + RE_HUNK_HEADER = re.compile(r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)") try: - for i,line in enumerate(patch_lines): + for i, line in enumerate(patch_lines): if line.startswith('@@'): match = RE_HUNK_HEADER.match(line) # identify hunk header if match: # finish processing previous hunk if is_valid_hunk and (start1 != -1 and patch_extra_lines_after > 0): - delta_lines = [f' {line}' for line in original_lines[start1 + size1 - 1:start1 + size1 - 1 + patch_extra_lines_after]] + delta_lines = [ + f' {line}' + for line in original_lines[ + start1 + + size1 + - 1 : start1 + + size1 + - 1 + + patch_extra_lines_after + ] + ] extended_patch_lines.extend(delta_lines) - section_header, size1, size2, start1, start2 = extract_hunk_headers(match) + section_header, size1, size2, start1, start2 = extract_hunk_headers( + match + ) - is_valid_hunk = check_if_hunk_lines_matches_to_file(i, original_lines, patch_lines, start1) + is_valid_hunk = check_if_hunk_lines_matches_to_file( + i, original_lines, patch_lines, start1 + ) + + if is_valid_hunk and ( + patch_extra_lines_before > 0 or patch_extra_lines_after > 0 + ): - if is_valid_hunk and (patch_extra_lines_before > 0 or patch_extra_lines_after > 0): def _calc_context_limits(patch_lines_before): extended_start1 = max(1, start1 - patch_lines_before) - extended_size1 = size1 + (start1 - extended_start1) + patch_extra_lines_after + extended_size1 = ( + size1 + + (start1 - extended_start1) + + patch_extra_lines_after + ) extended_start2 = max(1, start2 - patch_lines_before) - extended_size2 = size2 + (start2 - extended_start2) + patch_extra_lines_after - if extended_start1 - 1 + extended_size1 > len_original_lines: + extended_size2 = ( + size2 + + (start2 - extended_start2) + + patch_extra_lines_after + ) + if ( + extended_start1 - 1 + extended_size1 + > len_original_lines + ): # we cannot extend beyond the original file - delta_cap = extended_start1 - 1 + extended_size1 - len_original_lines + delta_cap = ( + extended_start1 + - 1 + + extended_size1 + - len_original_lines + ) extended_size1 = max(extended_size1 - delta_cap, size1) extended_size2 = max(extended_size2 - delta_cap, size2) - return extended_start1, extended_size1, extended_start2, extended_size2 + return ( + extended_start1, + extended_size1, + extended_start2, + extended_size2, + ) if allow_dynamic_context: - extended_start1, extended_size1, extended_start2, extended_size2 = \ - _calc_context_limits(patch_extra_lines_before_dynamic) - lines_before = original_lines[extended_start1 - 1:start1 - 1] + ( + extended_start1, + extended_size1, + extended_start2, + extended_size2, + ) = _calc_context_limits(patch_extra_lines_before_dynamic) + lines_before = original_lines[ + extended_start1 - 1 : start1 - 1 + ] found_header = False - for i, line, in 
enumerate(lines_before): + for ( + i, + line, + ) in enumerate(lines_before): if section_header in line: found_header = True # Update start and size in one line each - extended_start1, extended_start2 = extended_start1 + i, extended_start2 + i - extended_size1, extended_size2 = extended_size1 - i, extended_size2 - i + extended_start1, extended_start2 = ( + extended_start1 + i, + extended_start2 + i, + ) + extended_size1, extended_size2 = ( + extended_size1 - i, + extended_size2 - i, + ) # get_logger().debug(f"Found section header in line {i} before the hunk") section_header = '' break if not found_header: # get_logger().debug(f"Section header not found in the extra lines before the hunk") - extended_start1, extended_size1, extended_start2, extended_size2 = \ - _calc_context_limits(patch_extra_lines_before) + ( + extended_start1, + extended_size1, + extended_start2, + extended_size2, + ) = _calc_context_limits(patch_extra_lines_before) else: - extended_start1, extended_size1, extended_start2, extended_size2 = \ - _calc_context_limits(patch_extra_lines_before) + ( + extended_start1, + extended_size1, + extended_start2, + extended_size2, + ) = _calc_context_limits(patch_extra_lines_before) - delta_lines = [f' {line}' for line in original_lines[extended_start1 - 1:start1 - 1]] + delta_lines = [ + f' {line}' + for line in original_lines[extended_start1 - 1 : start1 - 1] + ] # logic to remove section header if its in the extra delta lines (in dynamic context, this is also done) if section_header and not allow_dynamic_context: @@ -132,17 +217,23 @@ def process_patch_lines(patch_str, original_file_str, patch_extra_lines_before, extended_patch_lines.append('') extended_patch_lines.append( f'@@ -{extended_start1},{extended_size1} ' - f'+{extended_start2},{extended_size2} @@ {section_header}') + f'+{extended_start2},{extended_size2} @@ {section_header}' + ) extended_patch_lines.extend(delta_lines) # one to zero based continue extended_patch_lines.append(line) except Exception as e: - get_logger().warning(f"Failed to extend patch: {e}", artifact={"traceback": traceback.format_exc()}) + get_logger().warning( + f"Failed to extend patch: {e}", + artifact={"traceback": traceback.format_exc()}, + ) return patch_str # finish processing last hunk if start1 != -1 and patch_extra_lines_after > 0 and is_valid_hunk: - delta_lines = original_lines[start1 + size1 - 1:start1 + size1 - 1 + patch_extra_lines_after] + delta_lines = original_lines[ + start1 + size1 - 1 : start1 + size1 - 1 + patch_extra_lines_after + ] # add space at the beginning of each extra line delta_lines = [f' {line}' for line in delta_lines] extended_patch_lines.extend(delta_lines) @@ -158,11 +249,14 @@ def check_if_hunk_lines_matches_to_file(i, original_lines, patch_lines, start1): """ is_valid_hunk = True try: - if i + 1 < len(patch_lines) and patch_lines[i + 1][0] == ' ': # an existing line in the file + if ( + i + 1 < len(patch_lines) and patch_lines[i + 1][0] == ' ' + ): # an existing line in the file if patch_lines[i + 1].strip() != original_lines[start1 - 1].strip(): is_valid_hunk = False get_logger().error( - f"Invalid hunk in PR, line {start1} in hunk header doesn't match the original file content") + f"Invalid hunk in PR, line {start1} in hunk header doesn't match the original file content" + ) except: pass return is_valid_hunk @@ -195,8 +289,7 @@ def omit_deletion_hunks(patch_lines) -> str: added_patched = [] add_hunk = False inside_hunk = False - RE_HUNK_HEADER = re.compile( - r"^@@ -(\d+)(?:,(\d+))? 
\+(\d+)(?:,(\d+))?\ @@[ ]?(.*)") + RE_HUNK_HEADER = re.compile(r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))?\ @@[ ]?(.*)") for line in patch_lines: if line.startswith('@@'): @@ -221,8 +314,13 @@ def omit_deletion_hunks(patch_lines) -> str: return '\n'.join(added_patched) -def handle_patch_deletions(patch: str, original_file_content_str: str, - new_file_content_str: str, file_name: str, edit_type: EDIT_TYPE = EDIT_TYPE.UNKNOWN) -> str: +def handle_patch_deletions( + patch: str, + original_file_content_str: str, + new_file_content_str: str, + file_name: str, + edit_type: EDIT_TYPE = EDIT_TYPE.UNKNOWN, +) -> str: """ Handle entire file or deletion patches. @@ -239,11 +337,13 @@ def handle_patch_deletions(patch: str, original_file_content_str: str, str: The modified patch with deletion hunks omitted. """ - if not new_file_content_str and (edit_type == EDIT_TYPE.DELETED or edit_type == EDIT_TYPE.UNKNOWN): + if not new_file_content_str and ( + edit_type == EDIT_TYPE.DELETED or edit_type == EDIT_TYPE.UNKNOWN + ): # logic for handling deleted files - don't show patch, just show that the file was deleted if get_settings().config.verbosity_level > 0: get_logger().info(f"Processing file: {file_name}, minimizing deletion file") - patch = None # file was deleted + patch = None # file was deleted else: patch_lines = patch.splitlines() patch_new = omit_deletion_hunks(patch_lines) @@ -256,35 +356,35 @@ def handle_patch_deletions(patch: str, original_file_content_str: str, def convert_to_hunks_with_lines_numbers(patch: str, file) -> str: """ - Convert a given patch string into a string with line numbers for each hunk, indicating the new and old content of - the file. + Convert a given patch string into a string with line numbers for each hunk, indicating the new and old content of + the file. - Args: - patch (str): The patch string to be converted. - file: An object containing the filename of the file being patched. + Args: + patch (str): The patch string to be converted. + file: An object containing the filename of the file being patched. - Returns: - str: A string with line numbers for each hunk, indicating the new and old content of the file. + Returns: + str: A string with line numbers for each hunk, indicating the new and old content of the file. - example output: -## src/file.ts -__new hunk__ -881 line1 -882 line2 -883 line3 -887 + line4 -888 + line5 -889 line6 -890 line7 -... -__old hunk__ - line1 - line2 -- line3 -- line4 - line5 - line6 - ... + example output: + ## src/file.ts + __new hunk__ + 881 line1 + 882 line2 + 883 line3 + 887 + line4 + 888 + line5 + 889 line6 + 890 line7 + ... + __old hunk__ + line1 + line2 + - line3 + - line4 + line5 + line6 + ... """ # if the file was deleted, return a message indicating that the file was deleted if hasattr(file, 'edit_type') and file.edit_type == EDIT_TYPE.DELETED: @@ -292,8 +392,7 @@ __old hunk__ patch_with_lines_str = f"\n\n## File: '{file.filename.strip()}'\n" patch_lines = patch.splitlines() - RE_HUNK_HEADER = re.compile( - r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)") + RE_HUNK_HEADER = re.compile(r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? 
@@[ ]?(.*)") new_content_lines = [] old_content_lines = [] match = None @@ -307,20 +406,32 @@ __old hunk__ if line.startswith('@@'): header_line = line match = RE_HUNK_HEADER.match(line) - if match and (new_content_lines or old_content_lines): # found a new hunk, split the previous lines + if match and ( + new_content_lines or old_content_lines + ): # found a new hunk, split the previous lines if prev_header_line: patch_with_lines_str += f'\n{prev_header_line}\n' is_plus_lines = is_minus_lines = False if new_content_lines: - is_plus_lines = any([line.startswith('+') for line in new_content_lines]) + is_plus_lines = any( + [line.startswith('+') for line in new_content_lines] + ) if old_content_lines: - is_minus_lines = any([line.startswith('-') for line in old_content_lines]) - if is_plus_lines or is_minus_lines: # notice 'True' here - we always present __new hunk__ for section, otherwise LLM gets confused - patch_with_lines_str = patch_with_lines_str.rstrip() + '\n__new hunk__\n' + is_minus_lines = any( + [line.startswith('-') for line in old_content_lines] + ) + if ( + is_plus_lines or is_minus_lines + ): # notice 'True' here - we always present __new hunk__ for section, otherwise LLM gets confused + patch_with_lines_str = ( + patch_with_lines_str.rstrip() + '\n__new hunk__\n' + ) for i, line_new in enumerate(new_content_lines): patch_with_lines_str += f"{start2 + i} {line_new}\n" if is_minus_lines: - patch_with_lines_str = patch_with_lines_str.rstrip() + '\n__old hunk__\n' + patch_with_lines_str = ( + patch_with_lines_str.rstrip() + '\n__old hunk__\n' + ) for line_old in old_content_lines: patch_with_lines_str += f"{line_old}\n" new_content_lines = [] @@ -335,8 +446,12 @@ __old hunk__ elif line.startswith('-'): old_content_lines.append(line) else: - if not line and line_i: # if this line is empty and the next line is a hunk header, skip it - if line_i + 1 < len(patch_lines) and patch_lines[line_i + 1].startswith('@@'): + if ( + not line and line_i + ): # if this line is empty and the next line is a hunk header, skip it + if line_i + 1 < len(patch_lines) and patch_lines[line_i + 1].startswith( + '@@' + ): continue elif line_i + 1 == len(patch_lines): continue @@ -351,7 +466,9 @@ __old hunk__ is_plus_lines = any([line.startswith('+') for line in new_content_lines]) if old_content_lines: is_minus_lines = any([line.startswith('-') for line in old_content_lines]) - if is_plus_lines or is_minus_lines: # notice 'True' here - we always present __new hunk__ for section, otherwise LLM gets confused + if ( + is_plus_lines or is_minus_lines + ): # notice 'True' here - we always present __new hunk__ for section, otherwise LLM gets confused patch_with_lines_str = patch_with_lines_str.rstrip() + '\n__new hunk__\n' for i, line_new in enumerate(new_content_lines): patch_with_lines_str += f"{start2 + i} {line_new}\n" @@ -363,13 +480,16 @@ __old hunk__ return patch_with_lines_str.rstrip() -def extract_hunk_lines_from_patch(patch: str, file_name, line_start, line_end, side) -> tuple[str, str]: +def extract_hunk_lines_from_patch( + patch: str, file_name, line_start, line_end, side +) -> tuple[str, str]: try: patch_with_lines_str = f"\n\n## File: '{file_name.strip()}'\n\n" selected_lines = "" patch_lines = patch.splitlines() RE_HUNK_HEADER = re.compile( - r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)") + r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? 
@@[ ]?(.*)" + ) match = None start1, size1, start2, size2 = -1, -1, -1, -1 skip_hunk = False @@ -385,7 +505,9 @@ def extract_hunk_lines_from_patch(patch: str, file_name, line_start, line_end, s match = RE_HUNK_HEADER.match(line) - section_header, size1, size2, start1, start2 = extract_hunk_headers(match) + section_header, size1, size2, start1, start2 = extract_hunk_headers( + match + ) # check if line range is in this hunk if side.lower() == 'left': @@ -400,15 +522,26 @@ def extract_hunk_lines_from_patch(patch: str, file_name, line_start, line_end, s patch_with_lines_str += f'\n{header_line}\n' elif not skip_hunk: - if side.lower() == 'right' and line_start <= start2 + selected_lines_num <= line_end: + if ( + side.lower() == 'right' + and line_start <= start2 + selected_lines_num <= line_end + ): selected_lines += line + '\n' - if side.lower() == 'left' and start1 <= selected_lines_num + start1 <= line_end: + if ( + side.lower() == 'left' + and start1 <= selected_lines_num + start1 <= line_end + ): selected_lines += line + '\n' patch_with_lines_str += line + '\n' - if not line.startswith('-'): # currently we don't support /ask line for deleted lines + if not line.startswith( + '-' + ): # currently we don't support /ask line for deleted lines selected_lines_num += 1 except Exception as e: - get_logger().error(f"Failed to extract hunk lines from patch: {e}", artifact={"traceback": traceback.format_exc()}) + get_logger().error( + f"Failed to extract hunk lines from patch: {e}", + artifact={"traceback": traceback.format_exc()}, + ) return "", "" return patch_with_lines_str.rstrip(), selected_lines.rstrip() diff --git a/apps/utils/pr_agent/algo/language_handler.py b/apps/utils/pr_agent/algo/language_handler.py index 1a6e3bd..f8a5163 100644 --- a/apps/utils/pr_agent/algo/language_handler.py +++ b/apps/utils/pr_agent/algo/language_handler.py @@ -9,10 +9,14 @@ def filter_bad_extensions(files): bad_extensions = get_settings().bad_extensions.default if get_settings().config.use_extra_bad_extensions: bad_extensions += get_settings().bad_extensions.extra - return [f for f in files if f.filename is not None and is_valid_file(f.filename, bad_extensions)] + return [ + f + for f in files + if f.filename is not None and is_valid_file(f.filename, bad_extensions) + ] -def is_valid_file(filename:str, bad_extensions=None) -> bool: +def is_valid_file(filename: str, bad_extensions=None) -> bool: if not filename: return False if not bad_extensions: @@ -27,12 +31,16 @@ def sort_files_by_main_languages(languages: Dict, files: list): Sort files by their main language, put the files that are in the main language first and the rest files after """ # sort languages by their size - languages_sorted_list = [k for k, v in sorted(languages.items(), key=lambda item: item[1], reverse=True)] + languages_sorted_list = [ + k for k, v in sorted(languages.items(), key=lambda item: item[1], reverse=True) + ] # languages_sorted = sorted(languages, key=lambda x: x[1], reverse=True) # get all extensions for the languages main_extensions = [] language_extension_map_org = get_settings().language_extension_map_org - language_extension_map = {k.lower(): v for k, v in language_extension_map_org.items()} + language_extension_map = { + k.lower(): v for k, v in language_extension_map_org.items() + } for language in languages_sorted_list: if language.lower() in language_extension_map: main_extensions.append(language_extension_map[language.lower()]) @@ -62,7 +70,9 @@ def sort_files_by_main_languages(languages: Dict, files: list): if 
extension_str in extensions: tmp.append(file) else: - if (file.filename not in rest_files) and (extension_str not in main_extensions_flat): + if (file.filename not in rest_files) and ( + extension_str not in main_extensions_flat + ): rest_files[file.filename] = file if len(tmp) > 0: files_sorted.append({"language": lang, "files": tmp}) diff --git a/apps/utils/pr_agent/algo/pr_processing.py b/apps/utils/pr_agent/algo/pr_processing.py index 19d29ae..ee1617c 100644 --- a/apps/utils/pr_agent/algo/pr_processing.py +++ b/apps/utils/pr_agent/algo/pr_processing.py @@ -7,18 +7,28 @@ from github import RateLimitExceededException from utils.pr_agent.algo.file_filter import filter_ignored from utils.pr_agent.algo.git_patch_processing import ( - convert_to_hunks_with_lines_numbers, extend_patch, handle_patch_deletions) + convert_to_hunks_with_lines_numbers, + extend_patch, + handle_patch_deletions, +) from utils.pr_agent.algo.language_handler import sort_files_by_main_languages from utils.pr_agent.algo.token_handler import TokenHandler from utils.pr_agent.algo.types import EDIT_TYPE -from utils.pr_agent.algo.utils import ModelType, clip_tokens, get_max_tokens, get_weak_model +from utils.pr_agent.algo.utils import ( + ModelType, + clip_tokens, + get_max_tokens, + get_weak_model, +) from utils.pr_agent.config_loader import get_settings from utils.pr_agent.git_providers.git_provider import GitProvider from utils.pr_agent.log import get_logger DELETED_FILES_ = "Deleted files:\n" -MORE_MODIFIED_FILES_ = "Additional modified files (insufficient token budget to process):\n" +MORE_MODIFIED_FILES_ = ( + "Additional modified files (insufficient token budget to process):\n" +) ADDED_FILES_ = "Additional added files (insufficient token budget to process):\n" @@ -29,45 +39,59 @@ MAX_EXTRA_LINES = 10 def cap_and_log_extra_lines(value, direction) -> int: if value > MAX_EXTRA_LINES: - get_logger().warning(f"patch_extra_lines_{direction} was {value}, capping to {MAX_EXTRA_LINES}") + get_logger().warning( + f"patch_extra_lines_{direction} was {value}, capping to {MAX_EXTRA_LINES}" + ) return MAX_EXTRA_LINES return value -def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler, - model: str, - add_line_numbers_to_hunks: bool = False, - disable_extra_lines: bool = False, - large_pr_handling=False, - return_remaining_files=False): +def get_pr_diff( + git_provider: GitProvider, + token_handler: TokenHandler, + model: str, + add_line_numbers_to_hunks: bool = False, + disable_extra_lines: bool = False, + large_pr_handling=False, + return_remaining_files=False, +): if disable_extra_lines: PATCH_EXTRA_LINES_BEFORE = 0 PATCH_EXTRA_LINES_AFTER = 0 else: PATCH_EXTRA_LINES_BEFORE = get_settings().config.patch_extra_lines_before PATCH_EXTRA_LINES_AFTER = get_settings().config.patch_extra_lines_after - PATCH_EXTRA_LINES_BEFORE = cap_and_log_extra_lines(PATCH_EXTRA_LINES_BEFORE, "before") - PATCH_EXTRA_LINES_AFTER = cap_and_log_extra_lines(PATCH_EXTRA_LINES_AFTER, "after") + PATCH_EXTRA_LINES_BEFORE = cap_and_log_extra_lines( + PATCH_EXTRA_LINES_BEFORE, "before" + ) + PATCH_EXTRA_LINES_AFTER = cap_and_log_extra_lines( + PATCH_EXTRA_LINES_AFTER, "after" + ) try: diff_files_original = git_provider.get_diff_files() except RateLimitExceededException as e: - get_logger().error(f"Rate limit exceeded for git provider API. original message {e}") + get_logger().error( + f"Rate limit exceeded for git provider API. 
original message {e}" + ) raise diff_files = filter_ignored(diff_files_original) if diff_files != diff_files_original: try: - get_logger().info(f"Filtered out {len(diff_files_original) - len(diff_files)} files") + get_logger().info( + f"Filtered out {len(diff_files_original) - len(diff_files)} files" + ) new_names = set([a.filename for a in diff_files]) orig_names = set([a.filename for a in diff_files_original]) get_logger().info(f"Filtered out files: {orig_names - new_names}") except Exception as e: pass - # get pr languages - pr_languages = sort_files_by_main_languages(git_provider.get_languages(), diff_files) + pr_languages = sort_files_by_main_languages( + git_provider.get_languages(), diff_files + ) if pr_languages: try: get_logger().info(f"PR main language: {pr_languages[0]['language']}") @@ -76,24 +100,42 @@ def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler, # generate a standard diff string, with patch extension patches_extended, total_tokens, patches_extended_tokens = pr_generate_extended_diff( - pr_languages, token_handler, add_line_numbers_to_hunks, - patch_extra_lines_before=PATCH_EXTRA_LINES_BEFORE, patch_extra_lines_after=PATCH_EXTRA_LINES_AFTER) + pr_languages, + token_handler, + add_line_numbers_to_hunks, + patch_extra_lines_before=PATCH_EXTRA_LINES_BEFORE, + patch_extra_lines_after=PATCH_EXTRA_LINES_AFTER, + ) # if we are under the limit, return the full diff if total_tokens + OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD < get_max_tokens(model): - get_logger().info(f"Tokens: {total_tokens}, total tokens under limit: {get_max_tokens(model)}, " - f"returning full diff.") + get_logger().info( + f"Tokens: {total_tokens}, total tokens under limit: {get_max_tokens(model)}, " + f"returning full diff." + ) return "\n".join(patches_extended) # if we are over the limit, start pruning (If we got here, we will not extend the patches with extra lines) - get_logger().info(f"Tokens: {total_tokens}, total tokens over limit: {get_max_tokens(model)}, " - f"pruning diff.") - patches_compressed_list, total_tokens_list, deleted_files_list, remaining_files_list, file_dict, files_in_patches_list = \ - pr_generate_compressed_diff(pr_languages, token_handler, model, add_line_numbers_to_hunks, large_pr_handling) + get_logger().info( + f"Tokens: {total_tokens}, total tokens over limit: {get_max_tokens(model)}, " + f"pruning diff." + ) + ( + patches_compressed_list, + total_tokens_list, + deleted_files_list, + remaining_files_list, + file_dict, + files_in_patches_list, + ) = pr_generate_compressed_diff( + pr_languages, token_handler, model, add_line_numbers_to_hunks, large_pr_handling + ) if large_pr_handling and len(patches_compressed_list) > 1: - get_logger().info(f"Large PR handling mode, and found {len(patches_compressed_list)} patches with original diff.") - return "" # return empty string, as we want to generate multiple patches with a different prompt + get_logger().info( + f"Large PR handling mode, and found {len(patches_compressed_list)} patches with original diff." 
+ ) + return "" # return empty string, as we want to generate multiple patches with a different prompt # return the first patch patches_compressed = patches_compressed_list[0] @@ -144,26 +186,37 @@ def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler, if deleted_list_str: final_diff = final_diff + "\n\n" + deleted_list_str - get_logger().debug(f"After pruning, added_list_str: {added_list_str}, modified_list_str: {modified_list_str}, " - f"deleted_list_str: {deleted_list_str}") + get_logger().debug( + f"After pruning, added_list_str: {added_list_str}, modified_list_str: {modified_list_str}, " + f"deleted_list_str: {deleted_list_str}" + ) if not return_remaining_files: return final_diff else: return final_diff, remaining_files_list -def get_pr_diff_multiple_patchs(git_provider: GitProvider, token_handler: TokenHandler, model: str, - add_line_numbers_to_hunks: bool = False, disable_extra_lines: bool = False): +def get_pr_diff_multiple_patchs( + git_provider: GitProvider, + token_handler: TokenHandler, + model: str, + add_line_numbers_to_hunks: bool = False, + disable_extra_lines: bool = False, +): try: diff_files_original = git_provider.get_diff_files() except RateLimitExceededException as e: - get_logger().error(f"Rate limit exceeded for git provider API. original message {e}") + get_logger().error( + f"Rate limit exceeded for git provider API. original message {e}" + ) raise diff_files = filter_ignored(diff_files_original) if diff_files != diff_files_original: try: - get_logger().info(f"Filtered out {len(diff_files_original) - len(diff_files)} files") + get_logger().info( + f"Filtered out {len(diff_files_original) - len(diff_files)} files" + ) new_names = set([a.filename for a in diff_files]) orig_names = set([a.filename for a in diff_files_original]) get_logger().info(f"Filtered out files: {orig_names - new_names}") @@ -171,24 +224,47 @@ def get_pr_diff_multiple_patchs(git_provider: GitProvider, token_handler: TokenH pass # get pr languages - pr_languages = sort_files_by_main_languages(git_provider.get_languages(), diff_files) + pr_languages = sort_files_by_main_languages( + git_provider.get_languages(), diff_files + ) if pr_languages: try: get_logger().info(f"PR main language: {pr_languages[0]['language']}") except Exception as e: pass - patches_compressed_list, total_tokens_list, deleted_files_list, remaining_files_list, file_dict, files_in_patches_list = \ - pr_generate_compressed_diff(pr_languages, token_handler, model, add_line_numbers_to_hunks, large_pr_handling=True) + ( + patches_compressed_list, + total_tokens_list, + deleted_files_list, + remaining_files_list, + file_dict, + files_in_patches_list, + ) = pr_generate_compressed_diff( + pr_languages, + token_handler, + model, + add_line_numbers_to_hunks, + large_pr_handling=True, + ) - return patches_compressed_list, total_tokens_list, deleted_files_list, remaining_files_list, file_dict, files_in_patches_list + return ( + patches_compressed_list, + total_tokens_list, + deleted_files_list, + remaining_files_list, + file_dict, + files_in_patches_list, + ) -def pr_generate_extended_diff(pr_languages: list, - token_handler: TokenHandler, - add_line_numbers_to_hunks: bool, - patch_extra_lines_before: int = 0, - patch_extra_lines_after: int = 0) -> Tuple[list, int, list]: +def pr_generate_extended_diff( + pr_languages: list, + token_handler: TokenHandler, + add_line_numbers_to_hunks: bool, + patch_extra_lines_before: int = 0, + patch_extra_lines_after: int = 0, +) -> Tuple[list, int, list]: total_tokens = 
token_handler.prompt_tokens # initial tokens patches_extended = [] patches_extended_tokens = [] @@ -200,20 +276,33 @@ def pr_generate_extended_diff(pr_languages: list, continue # extend each patch with extra lines of context - extended_patch = extend_patch(original_file_content_str, patch, - patch_extra_lines_before, patch_extra_lines_after, file.filename) + extended_patch = extend_patch( + original_file_content_str, + patch, + patch_extra_lines_before, + patch_extra_lines_after, + file.filename, + ) if not extended_patch: - get_logger().warning(f"Failed to extend patch for file: {file.filename}") + get_logger().warning( + f"Failed to extend patch for file: {file.filename}" + ) continue if add_line_numbers_to_hunks: - full_extended_patch = convert_to_hunks_with_lines_numbers(extended_patch, file) + full_extended_patch = convert_to_hunks_with_lines_numbers( + extended_patch, file + ) else: full_extended_patch = f"\n\n## File: '{file.filename.strip()}'\n{extended_patch.rstrip()}\n" # add AI-summary metadata to the patch - if file.ai_file_summary and get_settings().get("config.enable_ai_metadata", False): - full_extended_patch = add_ai_summary_top_patch(file, full_extended_patch) + if file.ai_file_summary and get_settings().get( + "config.enable_ai_metadata", False + ): + full_extended_patch = add_ai_summary_top_patch( + file, full_extended_patch + ) patch_tokens = token_handler.count_tokens(full_extended_patch) file.tokens = patch_tokens @@ -224,9 +313,13 @@ def pr_generate_extended_diff(pr_languages: list, return patches_extended, total_tokens, patches_extended_tokens -def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, model: str, - convert_hunks_to_line_numbers: bool, - large_pr_handling: bool) -> Tuple[list, list, list, list, dict, list]: +def pr_generate_compressed_diff( + top_langs: list, + token_handler: TokenHandler, + model: str, + convert_hunks_to_line_numbers: bool, + large_pr_handling: bool, +) -> Tuple[list, list, list, list, dict, list]: deleted_files_list = [] # sort each one of the languages in top_langs by the number of tokens in the diff @@ -244,8 +337,13 @@ def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, mo continue # removing delete-only hunks - patch = handle_patch_deletions(patch, original_file_content_str, - new_file_content_str, file.filename, file.edit_type) + patch = handle_patch_deletions( + patch, + original_file_content_str, + new_file_content_str, + file.filename, + file.edit_type, + ) if patch is None: if file.filename not in deleted_files_list: deleted_files_list.append(file.filename) @@ -259,30 +357,54 @@ def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, mo # patch = add_ai_summary_top_patch(file, patch) new_patch_tokens = token_handler.count_tokens(patch) - file_dict[file.filename] = {'patch': patch, 'tokens': new_patch_tokens, 'edit_type': file.edit_type} + file_dict[file.filename] = { + 'patch': patch, + 'tokens': new_patch_tokens, + 'edit_type': file.edit_type, + } max_tokens_model = get_max_tokens(model) # first iteration files_in_patches_list = [] - remaining_files_list = [file.filename for file in sorted_files] - patches_list =[] + remaining_files_list = [file.filename for file in sorted_files] + patches_list = [] total_tokens_list = [] - total_tokens, patches, remaining_files_list, files_in_patch_list = generate_full_patch(convert_hunks_to_line_numbers, file_dict, - max_tokens_model, remaining_files_list, token_handler) + ( + total_tokens, + patches, + 
remaining_files_list, + files_in_patch_list, + ) = generate_full_patch( + convert_hunks_to_line_numbers, + file_dict, + max_tokens_model, + remaining_files_list, + token_handler, + ) patches_list.append(patches) total_tokens_list.append(total_tokens) files_in_patches_list.append(files_in_patch_list) # additional iterations (if needed) if large_pr_handling: - NUMBER_OF_ALLOWED_ITERATIONS = get_settings().pr_description.max_ai_calls - 1 # one more call is to summarize - for i in range(NUMBER_OF_ALLOWED_ITERATIONS-1): + NUMBER_OF_ALLOWED_ITERATIONS = ( + get_settings().pr_description.max_ai_calls - 1 + ) # one more call is to summarize + for i in range(NUMBER_OF_ALLOWED_ITERATIONS - 1): if remaining_files_list: - total_tokens, patches, remaining_files_list, files_in_patch_list = generate_full_patch(convert_hunks_to_line_numbers, - file_dict, - max_tokens_model, - remaining_files_list, token_handler) + ( + total_tokens, + patches, + remaining_files_list, + files_in_patch_list, + ) = generate_full_patch( + convert_hunks_to_line_numbers, + file_dict, + max_tokens_model, + remaining_files_list, + token_handler, + ) if patches: patches_list.append(patches) total_tokens_list.append(total_tokens) @@ -290,11 +412,24 @@ def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, mo else: break - return patches_list, total_tokens_list, deleted_files_list, remaining_files_list, file_dict, files_in_patches_list + return ( + patches_list, + total_tokens_list, + deleted_files_list, + remaining_files_list, + file_dict, + files_in_patches_list, + ) -def generate_full_patch(convert_hunks_to_line_numbers, file_dict, max_tokens_model,remaining_files_list_prev, token_handler): - total_tokens = token_handler.prompt_tokens # initial tokens +def generate_full_patch( + convert_hunks_to_line_numbers, + file_dict, + max_tokens_model, + remaining_files_list_prev, + token_handler, +): + total_tokens = token_handler.prompt_tokens # initial tokens patches = [] remaining_files_list_new = [] files_in_patch_list = [] @@ -312,7 +447,10 @@ def generate_full_patch(convert_hunks_to_line_numbers, file_dict, max_tokens_mod continue # If the patch is too large, just show the file name - if total_tokens + new_patch_tokens > max_tokens_model - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD: + if ( + total_tokens + new_patch_tokens + > max_tokens_model - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD + ): # Current logic is to skip the patch if it's too large # TODO: Option for alternative logic to remove hunks from the patch to reduce the number of tokens # until we meet the requirements @@ -334,7 +472,9 @@ def generate_full_patch(convert_hunks_to_line_numbers, file_dict, max_tokens_mod return total_tokens, patches, remaining_files_list_new, files_in_patch_list -async def retry_with_fallback_models(f: Callable, model_type: ModelType = ModelType.REGULAR): +async def retry_with_fallback_models( + f: Callable, model_type: ModelType = ModelType.REGULAR +): all_models = _get_all_models(model_type) all_deployments = _get_all_deployments(all_models) # try each (model, deployment_id) pair until one is successful, otherwise raise exception @@ -347,11 +487,11 @@ async def retry_with_fallback_models(f: Callable, model_type: ModelType = ModelT get_settings().set("openai.deployment_id", deployment_id) return await f(model) except: - get_logger().warning( - f"Failed to generate prediction with {model}" - ) + get_logger().warning(f"Failed to generate prediction with {model}") if i == len(all_models) - 1: # If it's the last iteration - raise 
Exception(f"Failed to generate prediction with any model of {all_models}") + raise Exception( + f"Failed to generate prediction with any model of {all_models}" + ) def _get_all_models(model_type: ModelType = ModelType.REGULAR) -> List[str]: @@ -374,17 +514,21 @@ def _get_all_deployments(all_models: List[str]) -> List[str]: if fallback_deployments: all_deployments = [deployment_id] + fallback_deployments if len(all_deployments) < len(all_models): - raise ValueError(f"The number of deployments ({len(all_deployments)}) " - f"is less than the number of models ({len(all_models)})") + raise ValueError( + f"The number of deployments ({len(all_deployments)}) " + f"is less than the number of models ({len(all_models)})" + ) else: all_deployments = [deployment_id] * len(all_models) return all_deployments -def get_pr_multi_diffs(git_provider: GitProvider, - token_handler: TokenHandler, - model: str, - max_calls: int = 5) -> List[str]: +def get_pr_multi_diffs( + git_provider: GitProvider, + token_handler: TokenHandler, + model: str, + max_calls: int = 5, +) -> List[str]: """ Retrieves the diff files from a Git provider, sorts them by main language, and generates patches for each file. The patches are split into multiple groups based on the maximum number of tokens allowed for the given model. @@ -404,13 +548,17 @@ def get_pr_multi_diffs(git_provider: GitProvider, try: diff_files = git_provider.get_diff_files() except RateLimitExceededException as e: - get_logger().error(f"Rate limit exceeded for git provider API. original message {e}") + get_logger().error( + f"Rate limit exceeded for git provider API. original message {e}" + ) raise diff_files = filter_ignored(diff_files) # Sort files by main language - pr_languages = sort_files_by_main_languages(git_provider.get_languages(), diff_files) + pr_languages = sort_files_by_main_languages( + git_provider.get_languages(), diff_files + ) # Sort files within each language group by tokens in descending order sorted_files = [] @@ -420,14 +568,19 @@ def get_pr_multi_diffs(git_provider: GitProvider, # Get the maximum number of extra lines before and after the patch PATCH_EXTRA_LINES_BEFORE = get_settings().config.patch_extra_lines_before PATCH_EXTRA_LINES_AFTER = get_settings().config.patch_extra_lines_after - PATCH_EXTRA_LINES_BEFORE = cap_and_log_extra_lines(PATCH_EXTRA_LINES_BEFORE, "before") + PATCH_EXTRA_LINES_BEFORE = cap_and_log_extra_lines( + PATCH_EXTRA_LINES_BEFORE, "before" + ) PATCH_EXTRA_LINES_AFTER = cap_and_log_extra_lines(PATCH_EXTRA_LINES_AFTER, "after") # try first a single run with standard diff string, with patch extension, and no deletions patches_extended, total_tokens, patches_extended_tokens = pr_generate_extended_diff( - pr_languages, token_handler, add_line_numbers_to_hunks=True, + pr_languages, + token_handler, + add_line_numbers_to_hunks=True, patch_extra_lines_before=PATCH_EXTRA_LINES_BEFORE, - patch_extra_lines_after=PATCH_EXTRA_LINES_AFTER) + patch_extra_lines_after=PATCH_EXTRA_LINES_AFTER, + ) # if we are under the limit, return the full diff if total_tokens + OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD < get_max_tokens(model): @@ -450,27 +603,50 @@ def get_pr_multi_diffs(git_provider: GitProvider, continue # Remove delete-only hunks - patch = handle_patch_deletions(patch, original_file_content_str, new_file_content_str, file.filename, file.edit_type) + patch = handle_patch_deletions( + patch, + original_file_content_str, + new_file_content_str, + file.filename, + file.edit_type, + ) if patch is None: continue patch = 
convert_to_hunks_with_lines_numbers(patch, file) # add AI-summary metadata to the patch - if file.ai_file_summary and get_settings().get("config.enable_ai_metadata", False): + if file.ai_file_summary and get_settings().get( + "config.enable_ai_metadata", False + ): patch = add_ai_summary_top_patch(file, patch) new_patch_tokens = token_handler.count_tokens(patch) - if patch and (token_handler.prompt_tokens + new_patch_tokens) > get_max_tokens( - model) - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD: + if ( + patch + and (token_handler.prompt_tokens + new_patch_tokens) + > get_max_tokens(model) - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD + ): if get_settings().config.get('large_patch_policy', 'skip') == 'skip': get_logger().warning(f"Patch too large, skipping: {file.filename}") continue elif get_settings().config.get('large_patch_policy') == 'clip': - delta_tokens = get_max_tokens(model) - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD - token_handler.prompt_tokens - patch_clipped = clip_tokens(patch, delta_tokens, delete_last_line=True, num_input_tokens=new_patch_tokens) + delta_tokens = ( + get_max_tokens(model) + - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD + - token_handler.prompt_tokens + ) + patch_clipped = clip_tokens( + patch, + delta_tokens, + delete_last_line=True, + num_input_tokens=new_patch_tokens, + ) new_patch_tokens = token_handler.count_tokens(patch_clipped) - if patch_clipped and (token_handler.prompt_tokens + new_patch_tokens) > get_max_tokens( - model) - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD: + if ( + patch_clipped + and (token_handler.prompt_tokens + new_patch_tokens) + > get_max_tokens(model) - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD + ): get_logger().warning(f"Patch too large, skipping: {file.filename}") continue else: @@ -480,13 +656,16 @@ def get_pr_multi_diffs(git_provider: GitProvider, get_logger().warning(f"Patch too large, skipping: {file.filename}") continue - if patch and (total_tokens + new_patch_tokens > get_max_tokens(model) - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD): + if patch and ( + total_tokens + new_patch_tokens + > get_max_tokens(model) - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD + ): final_diff = "\n".join(patches) final_diff_list.append(final_diff) patches = [] total_tokens = token_handler.prompt_tokens call_number += 1 - if call_number > max_calls: # avoid creating new patches + if call_number > max_calls: # avoid creating new patches if get_settings().config.verbosity_level >= 2: get_logger().info(f"Reached max calls ({max_calls})") break @@ -497,7 +676,9 @@ def get_pr_multi_diffs(git_provider: GitProvider, patches.append(patch) total_tokens += new_patch_tokens if get_settings().config.verbosity_level >= 2: - get_logger().info(f"Tokens: {total_tokens}, last filename: {file.filename}") + get_logger().info( + f"Tokens: {total_tokens}, last filename: {file.filename}" + ) # Add the last chunk if patches: @@ -515,7 +696,10 @@ def add_ai_metadata_to_diff_files(git_provider, pr_description_files): if not pr_description_files: get_logger().warning(f"PR description files are empty.") return - available_files = {pr_file['full_file_name'].strip(): pr_file for pr_file in pr_description_files} + available_files = { + pr_file['full_file_name'].strip(): pr_file + for pr_file in pr_description_files + } diff_files = git_provider.get_diff_files() found_any_match = False for file in diff_files: @@ -524,11 +708,15 @@ def add_ai_metadata_to_diff_files(git_provider, pr_description_files): file.ai_file_summary = available_files[filename] found_any_match = True if not found_any_match: - get_logger().error(f"Failed to 
find any matching files between PR description and diff files.", - artifact={"pr_description_files": pr_description_files}) + get_logger().error( + f"Failed to find any matching files between PR description and diff files.", + artifact={"pr_description_files": pr_description_files}, + ) except Exception as e: - get_logger().error(f"Failed to add AI metadata to diff files: {e}", - artifact={"traceback": traceback.format_exc()}) + get_logger().error( + f"Failed to add AI metadata to diff files: {e}", + artifact={"traceback": traceback.format_exc()}, + ) def add_ai_summary_top_patch(file, full_extended_patch): @@ -537,14 +725,18 @@ def add_ai_summary_top_patch(file, full_extended_patch): full_extended_patch_lines = full_extended_patch.split("\n") for i, line in enumerate(full_extended_patch_lines): if line.startswith("## File:") or line.startswith("## file:"): - full_extended_patch_lines.insert(i + 1, - f"### AI-generated changes summary:\n{file.ai_file_summary['long_summary']}") + full_extended_patch_lines.insert( + i + 1, + f"### AI-generated changes summary:\n{file.ai_file_summary['long_summary']}", + ) full_extended_patch = "\n".join(full_extended_patch_lines) return full_extended_patch # if no '## File: ...' was found return full_extended_patch except Exception as e: - get_logger().error(f"Failed to add AI summary to the top of the patch: {e}", - artifact={"traceback": traceback.format_exc()}) + get_logger().error( + f"Failed to add AI summary to the top of the patch: {e}", + artifact={"traceback": traceback.format_exc()}, + ) return full_extended_patch diff --git a/apps/utils/pr_agent/algo/token_handler.py b/apps/utils/pr_agent/algo/token_handler.py index 4e4c9af..01c9035 100644 --- a/apps/utils/pr_agent/algo/token_handler.py +++ b/apps/utils/pr_agent/algo/token_handler.py @@ -15,12 +15,17 @@ class TokenEncoder: @classmethod def get_token_encoder(cls): model = get_settings().config.model - if cls._encoder_instance is None or model != cls._model: # Check without acquiring the lock for performance + if ( + cls._encoder_instance is None or model != cls._model + ): # Check without acquiring the lock for performance with cls._lock: # Lock acquisition to ensure thread safety if cls._encoder_instance is None or model != cls._model: cls._model = model - cls._encoder_instance = encoding_for_model(cls._model) if "gpt" in cls._model else get_encoding( - "cl100k_base") + cls._encoder_instance = ( + encoding_for_model(cls._model) + if "gpt" in cls._model + else get_encoding("cl100k_base") + ) return cls._encoder_instance @@ -49,7 +54,9 @@ class TokenHandler: """ self.encoder = TokenEncoder.get_token_encoder() if pr is not None: - self.prompt_tokens = self._get_system_user_tokens(pr, self.encoder, vars, system, user) + self.prompt_tokens = self._get_system_user_tokens( + pr, self.encoder, vars, system, user + ) def _get_system_user_tokens(self, pr, encoder, vars: dict, system, user): """ diff --git a/apps/utils/pr_agent/algo/utils.py b/apps/utils/pr_agent/algo/utils.py index 3dba680..75383a4 100644 --- a/apps/utils/pr_agent/algo/utils.py +++ b/apps/utils/pr_agent/algo/utils.py @@ -41,10 +41,12 @@ class Range(BaseModel): column_start: int = -1 column_end: int = -1 + class ModelType(str, Enum): REGULAR = "regular" WEAK = "weak" + class PRReviewHeader(str, Enum): REGULAR = "## PR 评审指南" INCREMENTAL = "## 增量 PR 评审指南" @@ -57,7 +59,9 @@ class PRDescriptionHeader(str, Enum): def get_setting(key: str) -> Any: try: key = key.upper() - return context.get("settings", global_settings).get(key, 
global_settings.get(key, None)) + return context.get("settings", global_settings).get( + key, global_settings.get(key, None) + ) except Exception: return global_settings.get(key, None) @@ -72,14 +76,29 @@ def emphasize_header(text: str, only_markdown=False, reference_link=None) -> str # Everything before the colon (inclusive) is wrapped in tags if only_markdown: if reference_link: - transformed_string = f"[**{text[:colon_position + 1]}**]({reference_link})\n" + text[colon_position + 1:] + transformed_string = ( + f"[**{text[:colon_position + 1]}**]({reference_link})\n" + + text[colon_position + 1 :] + ) else: - transformed_string = f"**{text[:colon_position + 1]}**\n" + text[colon_position + 1:] + transformed_string = ( + f"**{text[:colon_position + 1]}**\n" + + text[colon_position + 1 :] + ) else: if reference_link: - transformed_string = f"{text[:colon_position + 1]}
" + text[colon_position + 1:] + transformed_string = ( + f"{text[:colon_position + 1]}
" + + text[colon_position + 1 :] + ) else: - transformed_string = "" + text[:colon_position + 1] + "" +'
' + text[colon_position + 1:] + transformed_string = ( + "" + + text[: colon_position + 1] + + "" + + '
' + + text[colon_position + 1 :] + ) else: # If there's no ": ", return the original string transformed_string = text @@ -101,11 +120,14 @@ def unique_strings(input_list: List[str]) -> List[str]: seen.add(item) return unique_list -def convert_to_markdown_v2(output_data: dict, - gfm_supported: bool = True, - incremental_review=None, - git_provider=None, - files=None) -> str: + +def convert_to_markdown_v2( + output_data: dict, + gfm_supported: bool = True, + incremental_review=None, + git_provider=None, + files=None, +) -> str: """ Convert a dictionary of data into markdown format. Args: @@ -183,7 +205,9 @@ def convert_to_markdown_v2(output_data: dict, else: markdown_text += f"### {emoji} PR 包含测试\n\n" elif 'ticket compliance check' in key_nice.lower(): - markdown_text = ticket_markdown_logic(emoji, markdown_text, value, gfm_supported) + markdown_text = ticket_markdown_logic( + emoji, markdown_text, value, gfm_supported + ) elif 'security concerns' in key_nice.lower(): if gfm_supported: markdown_text += f"" @@ -220,7 +244,9 @@ def convert_to_markdown_v2(output_data: dict, if gfm_supported: markdown_text += f"" # markdown_text += f"{emoji} {key_nice}

\n\n" - markdown_text += f"{emoji} 建议评审的重点领域

\n\n" + markdown_text += ( + f"{emoji} 建议评审的重点领域

\n\n" + ) else: markdown_text += f"### {emoji} 建议评审的重点领域\n\n#### \n" for i, issue in enumerate(issues): @@ -235,9 +261,13 @@ def convert_to_markdown_v2(output_data: dict, start_line = int(str(issue.get('start_line', 0)).strip()) end_line = int(str(issue.get('end_line', 0)).strip()) - relevant_lines_str = extract_relevant_lines_str(end_line, files, relevant_file, start_line, dedent=True) + relevant_lines_str = extract_relevant_lines_str( + end_line, files, relevant_file, start_line, dedent=True + ) if git_provider: - reference_link = git_provider.get_line_link(relevant_file, start_line, end_line) + reference_link = git_provider.get_line_link( + relevant_file, start_line, end_line + ) else: reference_link = None @@ -256,7 +286,9 @@ def convert_to_markdown_v2(output_data: dict, issue_str = f"**{issue_header}**\n\n{issue_content}\n\n" markdown_text += f"{issue_str}\n\n" except Exception as e: - get_logger().exception(f"Failed to process 'Recommended focus areas for review': {e}") + get_logger().exception( + f"Failed to process 'Recommended focus areas for review': {e}" + ) if gfm_supported: markdown_text += f"\n" else: @@ -273,7 +305,9 @@ def convert_to_markdown_v2(output_data: dict, return markdown_text -def extract_relevant_lines_str(end_line, files, relevant_file, start_line, dedent=False) -> str: +def extract_relevant_lines_str( + end_line, files, relevant_file, start_line, dedent=False +) -> str: """ Finds 'relevant_file' in 'files', and extracts the lines from 'start_line' to 'end_line' string from the file content. """ @@ -286,10 +320,16 @@ def extract_relevant_lines_str(end_line, files, relevant_file, start_line, deden if not file.head_file: # as a fallback, extract relevant lines directly from patch patch = file.patch - get_logger().info(f"No content found in file: '{file.filename}' for 'extract_relevant_lines_str'. Using patch instead") - _, selected_lines = extract_hunk_lines_from_patch(patch, file.filename, start_line, end_line,side='right') + get_logger().info( + f"No content found in file: '{file.filename}' for 'extract_relevant_lines_str'. Using patch instead" + ) + _, selected_lines = extract_hunk_lines_from_patch( + patch, file.filename, start_line, end_line, side='right' + ) if not selected_lines: - get_logger().error(f"Failed to extract relevant lines from patch: {file.filename}") + get_logger().error( + f"Failed to extract relevant lines from patch: {file.filename}" + ) return "" # filter out '-' lines relevant_lines_str = "" @@ -299,12 +339,16 @@ def extract_relevant_lines_str(end_line, files, relevant_file, start_line, deden relevant_lines_str += line[1:] + '\n' else: relevant_file_lines = file.head_file.splitlines() - relevant_lines_str = "\n".join(relevant_file_lines[start_line - 1:end_line]) + relevant_lines_str = "\n".join( + relevant_file_lines[start_line - 1 : end_line] + ) if dedent and relevant_lines_str: # Remove the longest leading string of spaces and tabs common to all lines. 
relevant_lines_str = textwrap.dedent(relevant_lines_str) - relevant_lines_str = f"```{file.language}\n{relevant_lines_str}\n```" + relevant_lines_str = ( + f"```{file.language}\n{relevant_lines_str}\n```" + ) break return relevant_lines_str @@ -325,14 +369,21 @@ def ticket_markdown_logic(emoji, markdown_text, value, gfm_supported) -> str: ticket_url = ticket_analysis.get('ticket_url', '').strip() explanation = '' ticket_compliance_level = '' # Individual ticket compliance - fully_compliant_str = ticket_analysis.get('fully_compliant_requirements', '').strip() - not_compliant_str = ticket_analysis.get('not_compliant_requirements', '').strip() - requires_further_human_verification = ticket_analysis.get('requires_further_human_verification', - '').strip() + fully_compliant_str = ticket_analysis.get( + 'fully_compliant_requirements', '' + ).strip() + not_compliant_str = ticket_analysis.get( + 'not_compliant_requirements', '' + ).strip() + requires_further_human_verification = ticket_analysis.get( + 'requires_further_human_verification', '' + ).strip() if not fully_compliant_str and not not_compliant_str: - get_logger().debug(f"Ticket compliance has no requirements", - artifact={'ticket_url': ticket_url}) + get_logger().debug( + f"Ticket compliance has no requirements", + artifact={'ticket_url': ticket_url}, + ) continue # Calculate individual ticket compliance level @@ -353,19 +404,27 @@ def ticket_markdown_logic(emoji, markdown_text, value, gfm_supported) -> str: # build compliance string if fully_compliant_str: - explanation += f"Compliant requirements:\n\n{fully_compliant_str}\n\n" + explanation += ( + f"Compliant requirements:\n\n{fully_compliant_str}\n\n" + ) if not_compliant_str: - explanation += f"Non-compliant requirements:\n\n{not_compliant_str}\n\n" + explanation += ( + f"Non-compliant requirements:\n\n{not_compliant_str}\n\n" + ) if requires_further_human_verification: explanation += f"Requires further human verification:\n\n{requires_further_human_verification}\n\n" ticket_compliance_str += f"\n\n**[{ticket_url.split('/')[-1]}]({ticket_url}) - {ticket_compliance_level}**\n\n{explanation}\n\n" # for debugging if requires_further_human_verification: - get_logger().debug(f"Ticket compliance requires further human verification", - artifact={'ticket_url': ticket_url, - 'requires_further_human_verification': requires_further_human_verification, - 'compliance_level': ticket_compliance_level}) + get_logger().debug( + f"Ticket compliance requires further human verification", + artifact={ + 'ticket_url': ticket_url, + 'requires_further_human_verification': requires_further_human_verification, + 'compliance_level': ticket_compliance_level, + }, + ) except Exception as e: get_logger().exception(f"Failed to process ticket compliance: {e}") @@ -381,7 +440,10 @@ def ticket_markdown_logic(emoji, markdown_text, value, gfm_supported) -> str: compliance_emoji = '✅' elif any(level == 'Not compliant' for level in all_compliance_levels): # If there's a mix of compliant and non-compliant tickets - if any(level in ['Fully compliant', 'PR Code Verified'] for level in all_compliance_levels): + if any( + level in ['Fully compliant', 'PR Code Verified'] + for level in all_compliance_levels + ): compliance_level = 'Partially compliant' compliance_emoji = '🔶' else: @@ -395,7 +457,9 @@ def ticket_markdown_logic(emoji, markdown_text, value, gfm_supported) -> str: compliance_emoji = '✅' # Set extra statistics outside the ticket loop - get_settings().set('config.extra_statistics', {'compliance_level': 
compliance_level}) + get_settings().set( + 'config.extra_statistics', {'compliance_level': compliance_level} + ) # editing table row for ticket compliance analysis if gfm_supported: @@ -425,7 +489,9 @@ def process_can_be_split(emoji, value): for i, split in enumerate(value): title = split.get('title', '') relevant_files = split.get('relevant_files', []) - markdown_text += f"
\n子 PR 主题: {title}\n\n" + markdown_text += ( + f"
\n子 PR 主题: {title}\n\n" + ) markdown_text += f"___\n\n相关文件:\n\n" for file in relevant_files: markdown_text += f"- {file}\n" @@ -464,7 +530,9 @@ def process_can_be_split(emoji, value): return markdown_text -def parse_code_suggestion(code_suggestion: dict, i: int = 0, gfm_supported: bool = True) -> str: +def parse_code_suggestion( + code_suggestion: dict, i: int = 0, gfm_supported: bool = True +) -> str: """ Convert a dictionary of data into markdown format. @@ -484,15 +552,19 @@ def parse_code_suggestion(code_suggestion: dict, i: int = 0, gfm_supported: bool markdown_text += f"相关文件{relevant_file}" # continue elif sub_key.lower() == 'suggestion': - markdown_text += (f"{sub_key}      " - f"\n\n\n\n{sub_value.strip()}\n\n\n") + markdown_text += ( + f"{sub_key}      " + f"\n\n\n\n{sub_value.strip()}\n\n\n" + ) elif sub_key.lower() == 'relevant_line': markdown_text += f"相关行" sub_value_list = sub_value.split('](') relevant_line = sub_value_list[0].lstrip('`').lstrip('[') if len(sub_value_list) > 1: link = sub_value_list[1].rstrip(')').strip('`') - markdown_text += f"{relevant_line}" + markdown_text += ( + f"{relevant_line}" + ) else: markdown_text += f"{relevant_line}" markdown_text += "" @@ -505,11 +577,14 @@ def parse_code_suggestion(code_suggestion: dict, i: int = 0, gfm_supported: bool for sub_key, sub_value in code_suggestion.items(): if isinstance(sub_key, str): sub_key = sub_key.rstrip() - if isinstance(sub_value,str): + if isinstance(sub_value, str): sub_value = sub_value.rstrip() if isinstance(sub_value, dict): # "code example" markdown_text += f" - **{sub_key}:**\n" - for code_key, code_value in sub_value.items(): # 'before' and 'after' code + for ( + code_key, + code_value, + ) in sub_value.items(): # 'before' and 'after' code code_str = f"```\n{code_value}\n```" code_str_indented = textwrap.indent(code_str, ' ') markdown_text += f" - **{code_key}:**\n{code_str_indented}\n" @@ -520,7 +595,9 @@ def parse_code_suggestion(code_suggestion: dict, i: int = 0, gfm_supported: bool markdown_text += f" **{sub_key}:** {sub_value} \n" if "relevant_line" not in sub_key.lower(): # nicer presentation # markdown_text = markdown_text.rstrip('\n') + "\\\n" # works for gitlab - markdown_text = markdown_text.rstrip('\n') + " \n" # works for gitlab and bitbucker + markdown_text = ( + markdown_text.rstrip('\n') + " \n" + ) # works for gitlab and bitbucker markdown_text += "\n" return markdown_text @@ -561,9 +638,15 @@ def try_fix_json(review, max_iter=10, code_suggestions=False): else: closing_bracket = "]}}" - if (review.rfind("'Code feedback': [") > 0 or review.rfind('"Code feedback": [') > 0) or \ - (review.rfind("'Code suggestions': [") > 0 or review.rfind('"Code suggestions": [') > 0) : - last_code_suggestion_ind = [m.end() for m in re.finditer(r"\}\s*,", review)][-1] - 1 + if ( + review.rfind("'Code feedback': [") > 0 or review.rfind('"Code feedback": [') > 0 + ) or ( + review.rfind("'Code suggestions': [") > 0 + or review.rfind('"Code suggestions": [') > 0 + ): + last_code_suggestion_ind = [m.end() for m in re.finditer(r"\}\s*,", review)][ + -1 + ] - 1 valid_json = False iter_count = 0 @@ -574,7 +657,9 @@ def try_fix_json(review, max_iter=10, code_suggestions=False): review = review[:last_code_suggestion_ind].strip() + closing_bracket except json.decoder.JSONDecodeError: review = review[:last_code_suggestion_ind] - last_code_suggestion_ind = [m.end() for m in re.finditer(r"\}\s*,", review)][-1] - 1 + last_code_suggestion_ind = [ + m.end() for m in re.finditer(r"\}\s*,", review) + ][-1] - 1 iter_count 
+= 1 if not valid_json: @@ -629,7 +714,12 @@ def convert_str_to_datetime(date_str): return datetime.strptime(date_str, datetime_format) -def load_large_diff(filename, new_file_content_str: str, original_file_content_str: str, show_warning: bool = True) -> str: +def load_large_diff( + filename, + new_file_content_str: str, + original_file_content_str: str, + show_warning: bool = True, +) -> str: """ Generate a patch for a modified file by comparing the original content of the file with the new content provided as input. @@ -640,10 +730,14 @@ def load_large_diff(filename, new_file_content_str: str, original_file_content_s try: original_file_content_str = (original_file_content_str or "").rstrip() + "\n" new_file_content_str = (new_file_content_str or "").rstrip() + "\n" - diff = difflib.unified_diff(original_file_content_str.splitlines(keepends=True), - new_file_content_str.splitlines(keepends=True)) + diff = difflib.unified_diff( + original_file_content_str.splitlines(keepends=True), + new_file_content_str.splitlines(keepends=True), + ) if get_settings().config.verbosity_level >= 2 and show_warning: - get_logger().info(f"File was modified, but no patch was found. Manually creating patch: {filename}.") + get_logger().info( + f"File was modified, but no patch was found. Manually creating patch: {filename}." + ) patch = ''.join(diff) return patch except Exception as e: @@ -693,42 +787,68 @@ def _fix_key_value(key: str, value: str): try: value = yaml.safe_load(value) except Exception as e: - get_logger().debug(f"Failed to parse YAML for config override {key}={value}", exc_info=e) + get_logger().debug( + f"Failed to parse YAML for config override {key}={value}", exc_info=e + ) return key, value -def load_yaml(response_text: str, keys_fix_yaml: List[str] = [], first_key="", last_key="") -> dict: - response_text = response_text.strip('\n').removeprefix('```yaml').rstrip().removesuffix('```') +def load_yaml( + response_text: str, keys_fix_yaml: List[str] = [], first_key="", last_key="" +) -> dict: + response_text = ( + response_text.strip('\n').removeprefix('```yaml').rstrip().removesuffix('```') + ) try: data = yaml.safe_load(response_text) except Exception as e: get_logger().warning(f"Initial failure to parse AI prediction: {e}") - data = try_fix_yaml(response_text, keys_fix_yaml=keys_fix_yaml, first_key=first_key, last_key=last_key) + data = try_fix_yaml( + response_text, + keys_fix_yaml=keys_fix_yaml, + first_key=first_key, + last_key=last_key, + ) if not data: - get_logger().error(f"Failed to parse AI prediction after fallbacks", - artifact={'response_text': response_text}) + get_logger().error( + f"Failed to parse AI prediction after fallbacks", + artifact={'response_text': response_text}, + ) else: - get_logger().info(f"Successfully parsed AI prediction after fallbacks", - artifact={'response_text': response_text}) + get_logger().info( + f"Successfully parsed AI prediction after fallbacks", + artifact={'response_text': response_text}, + ) return data - -def try_fix_yaml(response_text: str, - keys_fix_yaml: List[str] = [], - first_key="", - last_key="",) -> dict: +def try_fix_yaml( + response_text: str, + keys_fix_yaml: List[str] = [], + first_key="", + last_key="", +) -> dict: response_text_lines = response_text.split('\n') - keys_yaml = ['relevant line:', 'suggestion content:', 'relevant file:', 'existing code:', 'improved code:'] + keys_yaml = [ + 'relevant line:', + 'suggestion content:', + 'relevant file:', + 'existing code:', + 'improved code:', + ] keys_yaml = keys_yaml + keys_fix_yaml 
# first fallback - try to convert 'relevant line: ...' to relevant line: |-\n ...' response_text_lines_copy = response_text_lines.copy() for i in range(0, len(response_text_lines_copy)): for key in keys_yaml: - if key in response_text_lines_copy[i] and not '|' in response_text_lines_copy[i]: - response_text_lines_copy[i] = response_text_lines_copy[i].replace(f'{key}', - f'{key} |\n ') + if ( + key in response_text_lines_copy[i] + and not '|' in response_text_lines_copy[i] + ): + response_text_lines_copy[i] = response_text_lines_copy[i].replace( + f'{key}', f'{key} |\n ' + ) try: data = yaml.safe_load('\n'.join(response_text_lines_copy)) get_logger().info(f"Successfully parsed AI prediction after adding |-\n") @@ -743,22 +863,26 @@ def try_fix_yaml(response_text: str, snippet_text = snippet.group() try: data = yaml.safe_load(snippet_text.removeprefix('```yaml').rstrip('`')) - get_logger().info(f"Successfully parsed AI prediction after extracting yaml snippet") + get_logger().info( + f"Successfully parsed AI prediction after extracting yaml snippet" + ) return data except: pass - # third fallback - try to remove leading and trailing curly brackets - response_text_copy = response_text.strip().rstrip().removeprefix('{').removesuffix('}').rstrip(':\n') + response_text_copy = ( + response_text.strip().rstrip().removeprefix('{').removesuffix('}').rstrip(':\n') + ) try: data = yaml.safe_load(response_text_copy) - get_logger().info(f"Successfully parsed AI prediction after removing curly brackets") + get_logger().info( + f"Successfully parsed AI prediction after removing curly brackets" + ) return data except: pass - # forth fallback - try to extract yaml snippet by 'first_key' and 'last_key' # note that 'last_key' can be in practice a key that is not the last key in the yaml snippet. 
# it just needs to be some inner key, so we can look for newlines after it @@ -767,13 +891,23 @@ def try_fix_yaml(response_text: str, if index_start == -1: index_start = response_text.find(f"{first_key}:") index_last_code = response_text.rfind(f"{last_key}:") - index_end = response_text.find("\n\n", index_last_code) # look for newlines after last_key + index_end = response_text.find( + "\n\n", index_last_code + ) # look for newlines after last_key if index_end == -1: index_end = len(response_text) - response_text_copy = response_text[index_start:index_end].strip().strip('```yaml').strip('`').strip() + response_text_copy = ( + response_text[index_start:index_end] + .strip() + .strip('```yaml') + .strip('`') + .strip() + ) try: data = yaml.safe_load(response_text_copy) - get_logger().info(f"Successfully parsed AI prediction after extracting yaml snippet") + get_logger().info( + f"Successfully parsed AI prediction after extracting yaml snippet" + ) return data except: pass @@ -784,7 +918,9 @@ def try_fix_yaml(response_text: str, response_text_lines_copy[i] = ' ' + response_text_lines_copy[i][1:] try: data = yaml.safe_load('\n'.join(response_text_lines_copy)) - get_logger().info(f"Successfully parsed AI prediction after removing leading '+'") + get_logger().info( + f"Successfully parsed AI prediction after removing leading '+'" + ) return data except: pass @@ -794,7 +930,9 @@ def try_fix_yaml(response_text: str, response_text_lines_tmp = '\n'.join(response_text_lines[:-i]) try: data = yaml.safe_load(response_text_lines_tmp) - get_logger().info(f"Successfully parsed AI prediction after removing {i} lines") + get_logger().info( + f"Successfully parsed AI prediction after removing {i} lines" + ) return data except: pass @@ -820,11 +958,14 @@ def set_custom_labels(variables, git_provider=None): for k, v in labels.items(): description = "'" + v['description'].strip('\n').replace('\n', '\\n') + "'" # variables["custom_labels_class"] += f"\n {k.lower().replace(' ', '_')} = '{k}' # {description}" - variables["custom_labels_class"] += f"\n {k.lower().replace(' ', '_')} = {description}" + variables[ + "custom_labels_class" + ] += f"\n {k.lower().replace(' ', '_')} = {description}" labels_minimal_to_labels_dict[k.lower().replace(' ', '_')] = k counter += 1 variables["labels_minimal_to_labels_dict"] = labels_minimal_to_labels_dict + def get_user_labels(current_labels: List[str] = None): """ Only keep labels that has been added by the user @@ -866,14 +1007,22 @@ def get_max_tokens(model): elif settings.config.custom_model_max_tokens > 0: max_tokens_model = settings.config.custom_model_max_tokens else: - raise Exception(f"Ensure {model} is defined in MAX_TOKENS in ./pr_agent/algo/__init__.py or set a positive value for it in config.custom_model_max_tokens") + raise Exception( + f"Ensure {model} is defined in MAX_TOKENS in ./pr_agent/algo/__init__.py or set a positive value for it in config.custom_model_max_tokens" + ) if settings.config.max_model_tokens and settings.config.max_model_tokens > 0: max_tokens_model = min(settings.config.max_model_tokens, max_tokens_model) return max_tokens_model -def clip_tokens(text: str, max_tokens: int, add_three_dots=True, num_input_tokens=None, delete_last_line=False) -> str: +def clip_tokens( + text: str, + max_tokens: int, + add_three_dots=True, + num_input_tokens=None, + delete_last_line=False, +) -> str: """ Clip the number of tokens in a string to a maximum number of tokens. 
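The surrounding hunks show only the edges of clip_tokens. As a rough, minimal sketch of the character-ratio clipping idea described in its docstring (the helper name, the 0.9 safety margin, and the ratio-based size estimate are assumptions, not the patched implementation; only the rsplit-based last-line removal and the "...(truncated)" marker are visible in the hunk below):

# Minimal sketch (assumed logic, not the repository's code): clip `text` so it
# fits an approximate token budget, given a known token count for the input.
def clip_text_to_token_budget(text: str, max_tokens: int, num_input_tokens: int,
                              add_three_dots: bool = True,
                              delete_last_line: bool = False) -> str:
    if not text or num_input_tokens <= max_tokens:
        return text
    if max_tokens <= 0:
        return ""
    chars_per_token = len(text) / num_input_tokens               # average characters per token
    num_output_chars = int(0.9 * chars_per_token * max_tokens)   # keep a 10% safety margin (assumption)
    clipped_text = text[:num_output_chars]
    if delete_last_line:
        clipped_text = clipped_text.rsplit('\n', 1)[0]           # drop a possibly half-cut last line
    if add_three_dots:
        clipped_text += "\n...(truncated)"
    return clipped_text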
@@ -909,14 +1058,15 @@ def clip_tokens(text: str, max_tokens: int, add_three_dots=True, num_input_token clipped_text = clipped_text.rsplit('\n', 1)[0] if add_three_dots: clipped_text += "\n...(truncated)" - else: # if the text is empty - clipped_text = "" + else: # if the text is empty + clipped_text = "" return clipped_text except Exception as e: get_logger().warning(f"Failed to clip tokens: {e}") return text + def replace_code_tags(text): """ Replace odd instances of ` with and even instances of ` with @@ -928,15 +1078,16 @@ def replace_code_tags(text): return ''.join(parts) -def find_line_number_of_relevant_line_in_file(diff_files: List[FilePatchInfo], - relevant_file: str, - relevant_line_in_file: str, - absolute_position: int = None) -> Tuple[int, int]: +def find_line_number_of_relevant_line_in_file( + diff_files: List[FilePatchInfo], + relevant_file: str, + relevant_line_in_file: str, + absolute_position: int = None, +) -> Tuple[int, int]: position = -1 if absolute_position is None: absolute_position = -1 - re_hunk_header = re.compile( - r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)") + re_hunk_header = re.compile(r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)") if not diff_files: return position, absolute_position @@ -947,7 +1098,7 @@ def find_line_number_of_relevant_line_in_file(diff_files: List[FilePatchInfo], patch_lines = patch.splitlines() delta = 0 start1, size1, start2, size2 = 0, 0, 0, 0 - if absolute_position != -1: # matching absolute to relative + if absolute_position != -1: # matching absolute to relative for i, line in enumerate(patch_lines): # new hunk if line.startswith('@@'): @@ -965,12 +1116,12 @@ def find_line_number_of_relevant_line_in_file(diff_files: List[FilePatchInfo], break else: # try to find the line in the patch using difflib, with some margin of error - matches_difflib: list[str | Any] = difflib.get_close_matches(relevant_line_in_file, - patch_lines, n=3, cutoff=0.93) + matches_difflib: list[str | Any] = difflib.get_close_matches( + relevant_line_in_file, patch_lines, n=3, cutoff=0.93 + ) if len(matches_difflib) == 1 and matches_difflib[0].startswith('+'): relevant_line_in_file = matches_difflib[0] - for i, line in enumerate(patch_lines): if line.startswith('@@'): delta = 0 @@ -1002,19 +1153,26 @@ def find_line_number_of_relevant_line_in_file(diff_files: List[FilePatchInfo], break return position, absolute_position + def get_rate_limit_status(github_token) -> dict: - GITHUB_API_URL = get_settings(use_context=False).get("GITHUB.BASE_URL", "https://api.github.com").rstrip("/") # "https://api.github.com" + GITHUB_API_URL = ( + get_settings(use_context=False) + .get("GITHUB.BASE_URL", "https://api.github.com") + .rstrip("/") + ) # "https://api.github.com" # GITHUB_API_URL = "https://api.github.com" RATE_LIMIT_URL = f"{GITHUB_API_URL}/rate_limit" HEADERS = { "Accept": "application/vnd.github.v3+json", - "Authorization": f"token {github_token}" + "Authorization": f"token {github_token}", } response = requests.get(RATE_LIMIT_URL, headers=HEADERS) try: rate_limit_info = response.json() - if rate_limit_info.get('message') == 'Rate limiting is not enabled.': # for github enterprise + if ( + rate_limit_info.get('message') == 'Rate limiting is not enabled.' 
+ ): # for github enterprise return {'resources': {}} response.raise_for_status() # Check for HTTP errors except: # retry @@ -1024,12 +1182,16 @@ def get_rate_limit_status(github_token) -> dict: return rate_limit_info -def validate_rate_limit_github(github_token, installation_id=None, threshold=0.1) -> bool: +def validate_rate_limit_github( + github_token, installation_id=None, threshold=0.1 +) -> bool: try: rate_limit_status = get_rate_limit_status(github_token) if installation_id: - get_logger().debug(f"installation_id: {installation_id}, Rate limit status: {rate_limit_status['rate']}") - # validate that the rate limit is not exceeded + get_logger().debug( + f"installation_id: {installation_id}, Rate limit status: {rate_limit_status['rate']}" + ) + # validate that the rate limit is not exceeded # validate that the rate limit is not exceeded for key, value in rate_limit_status['resources'].items(): if value['remaining'] < value['limit'] * threshold: @@ -1037,8 +1199,9 @@ def validate_rate_limit_github(github_token, installation_id=None, threshold=0.1 return False return True except Exception as e: - get_logger().error(f"Error in rate limit {e}", - artifact={"traceback": traceback.format_exc()}) + get_logger().error( + f"Error in rate limit {e}", artifact={"traceback": traceback.format_exc()} + ) return True @@ -1051,7 +1214,9 @@ def validate_and_await_rate_limit(github_token): get_logger().error(f"key: {key}, value: {value}") sleep_time_sec = value['reset'] - datetime.now().timestamp() sleep_time_hour = sleep_time_sec / 3600.0 - get_logger().error(f"Rate limit exceeded. Sleeping for {sleep_time_hour} hours") + get_logger().error( + f"Rate limit exceeded. Sleeping for {sleep_time_hour} hours" + ) if sleep_time_sec > 0: time.sleep(sleep_time_sec + 1) rate_limit_status = get_rate_limit_status(github_token) @@ -1068,22 +1233,39 @@ def github_action_output(output_data: dict, key_name: str): key_data = output_data.get(key_name, {}) with open(os.environ['GITHUB_OUTPUT'], 'a') as fh: - print(f"{key_name}={json.dumps(key_data, indent=None, ensure_ascii=False)}", file=fh) + print( + f"{key_name}={json.dumps(key_data, indent=None, ensure_ascii=False)}", + file=fh, + ) except Exception as e: get_logger().error(f"Failed to write to GitHub Action output: {e}") return def show_relevant_configurations(relevant_section: str) -> str: - skip_keys = ['ai_disclaimer', 'ai_disclaimer_title', 'ANALYTICS_FOLDER', 'secret_provider', "skip_keys", "app_id", "redirect", - 'trial_prefix_message', 'no_eligible_message', 'identity_provider', 'ALLOWED_REPOS','APP_NAME'] + skip_keys = [ + 'ai_disclaimer', + 'ai_disclaimer_title', + 'ANALYTICS_FOLDER', + 'secret_provider', + "skip_keys", + "app_id", + "redirect", + 'trial_prefix_message', + 'no_eligible_message', + 'identity_provider', + 'ALLOWED_REPOS', + 'APP_NAME', + ] extra_skip_keys = get_settings().config.get('config.skip_keys', []) if extra_skip_keys: skip_keys.extend(extra_skip_keys) markdown_text = "" - markdown_text += "\n
\n
🛠️ 相关配置: \n\n" - markdown_text +="
以下是相关工具地配置 [configurations](https://github.com/Codium-ai/pr-agent/blob/main/pr_agent/settings/configuration.toml):\n\n" + markdown_text += ( + "\n
\n
🛠️ 相关配置: \n\n" + ) + markdown_text += "
以下是相关工具的配置 [configurations](https://github.com/Codium-ai/pr-agent/blob/main/pr_agent/settings/configuration.toml):\n\n"
     markdown_text += f"**[config**]\n```yaml\n\n"
     for key, value in get_settings().config.items():
         if key in skip_keys:
@@ -1099,6 +1281,7 @@ def show_relevant_configurations(relevant_section: str) -> str:
     markdown_text += "\n
\n" return markdown_text + def is_value_no(value): if not value: return True @@ -1122,7 +1305,7 @@ def string_to_uniform_number(s: str) -> float: # Convert the hash to an integer hash_int = int(hash_object.hexdigest(), 16) # Normalize the integer to the range [0, 1] - max_hash_int = 2 ** 256 - 1 + max_hash_int = 2**256 - 1 uniform_number = float(hash_int) / max_hash_int return uniform_number @@ -1131,7 +1314,9 @@ def process_description(description_full: str) -> Tuple[str, List]: if not description_full: return "", [] - description_split = description_full.split(PRDescriptionHeader.CHANGES_WALKTHROUGH.value) + description_split = description_full.split( + PRDescriptionHeader.CHANGES_WALKTHROUGH.value + ) base_description_str = description_split[0] changes_walkthrough_str = "" files = [] @@ -1167,45 +1352,58 @@ def process_description(description_full: str) -> Tuple[str, List]: pattern_back = r'
\s*(.*?)
(.*?).*?
\s*
\s*(.*?)\n\n\s*(.*?)
' res = re.search(pattern_back, file_data, re.DOTALL) if not res or res.lastindex != 4: - pattern_back = r'
\s*(.*?)\s*
(.*?).*?
\s*
\s*(.*?)\s*-\s*(.*?)\s*
' # looking for hypen ('- ') + pattern_back = r'
\s*(.*?)\s*
(.*?).*?
\s*
\s*(.*?)\s*-\s*(.*?)\s*
' # looking for hypen ('- ') res = re.search(pattern_back, file_data, re.DOTALL) if res and res.lastindex == 4: short_filename = res.group(1).strip() short_summary = res.group(2).strip() long_filename = res.group(3).strip() - long_summary = res.group(4).strip() - long_summary = long_summary.replace('
*', '\n*').replace('
','').replace('\n','
') + long_summary = res.group(4).strip() + long_summary = ( + long_summary.replace('
*', '\n*') + .replace('
', '') + .replace('\n', '
') + ) long_summary = h.handle(long_summary).strip() if long_summary.startswith('\\-'): long_summary = "* " + long_summary[2:] elif not long_summary.startswith('*'): long_summary = f"* {long_summary}" - files.append({ - 'short_file_name': short_filename, - 'full_file_name': long_filename, - 'short_summary': short_summary, - 'long_summary': long_summary - }) + files.append( + { + 'short_file_name': short_filename, + 'full_file_name': long_filename, + 'short_summary': short_summary, + 'long_summary': long_summary, + } + ) else: if '...' in file_data: - pass # PR with many files. some did not get analyzed + pass # PR with many files. some did not get analyzed else: - get_logger().error(f"Failed to parse description", artifact={'description': file_data}) + get_logger().error( + f"Failed to parse description", + artifact={'description': file_data}, + ) except Exception as e: - get_logger().exception(f"Failed to process description: {e}", artifact={'description': file_data}) - + get_logger().exception( + f"Failed to process description: {e}", + artifact={'description': file_data}, + ) except Exception as e: get_logger().exception(f"Failed to process description: {e}") return base_description_str, files + def get_version() -> str: # First check pyproject.toml if running directly out of repository if os.path.exists("pyproject.toml"): if sys.version_info >= (3, 11): import tomllib + with open("pyproject.toml", "rb") as f: data = tomllib.load(f) if "project" in data and "version" in data["project"]: @@ -1213,7 +1411,9 @@ def get_version() -> str: else: get_logger().warning("Version not found in pyproject.toml") else: - get_logger().warning("Unable to determine local version from pyproject.toml") + get_logger().warning( + "Unable to determine local version from pyproject.toml" + ) # Otherwise get the installed pip package version try: diff --git a/apps/utils/pr_agent/cli.py b/apps/utils/pr_agent/cli.py index ae72ec7..bc6e665 100644 --- a/apps/utils/pr_agent/cli.py +++ b/apps/utils/pr_agent/cli.py @@ -12,8 +12,9 @@ setup_logger(log_level) def set_parser(): - parser = argparse.ArgumentParser(description='AI based pull request analyzer', usage= - """\ + parser = argparse.ArgumentParser( + description='AI based pull request analyzer', + usage="""\ Usage: cli.py --pr-url= []. For example: - cli.py --pr_url=... review @@ -45,11 +46,20 @@ def set_parser(): Configuration: To edit any configuration parameter from 'configuration.toml', just add -config_path=. For example: 'python cli.py --pr_url=... 
review --pr_reviewer.extra_instructions="focus on the file: ..."' - """) - parser.add_argument('--version', action='version', version=f'pr-agent {get_version()}') - parser.add_argument('--pr_url', type=str, help='The URL of the PR to review', default=None) - parser.add_argument('--issue_url', type=str, help='The URL of the Issue to review', default=None) - parser.add_argument('command', type=str, help='The', choices=commands, default='review') + """, + ) + parser.add_argument( + '--version', action='version', version=f'pr-agent {get_version()}' + ) + parser.add_argument( + '--pr_url', type=str, help='The URL of the PR to review', default=None + ) + parser.add_argument( + '--issue_url', type=str, help='The URL of the Issue to review', default=None + ) + parser.add_argument( + 'command', type=str, help='The', choices=commands, default='review' + ) parser.add_argument('rest', nargs=argparse.REMAINDER, default=[]) return parser @@ -76,14 +86,24 @@ def run(inargs=None, args=None): async def inner(): if args.issue_url: - result = await asyncio.create_task(PRAgent().handle_request(args.issue_url, [command] + args.rest)) + result = await asyncio.create_task( + PRAgent().handle_request(args.issue_url, [command] + args.rest) + ) else: - result = await asyncio.create_task(PRAgent().handle_request(args.pr_url, [command] + args.rest)) + result = await asyncio.create_task( + PRAgent().handle_request(args.pr_url, [command] + args.rest) + ) if get_settings().litellm.get("enable_callbacks", False): # There may be additional events on the event queue from the run above. If there are give them time to complete. get_logger().debug("Waiting for event queue to complete") - await asyncio.wait([task for task in asyncio.all_tasks() if task is not asyncio.current_task()]) + await asyncio.wait( + [ + task + for task in asyncio.all_tasks() + if task is not asyncio.current_task() + ] + ) return result diff --git a/apps/utils/pr_agent/cli_pip.py b/apps/utils/pr_agent/cli_pip.py index 9604bf0..61e5b66 100644 --- a/apps/utils/pr_agent/cli_pip.py +++ b/apps/utils/pr_agent/cli_pip.py @@ -7,7 +7,9 @@ def main(): provider = "github" # GitHub provider user_token = "..." # GitHub user token openai_key = "..." # OpenAI key - pr_url = "..." # PR URL, for example 'https://github.com/Codium-ai/pr-agent/pull/809' + pr_url = ( + "..." # PR URL, for example 'https://github.com/Codium-ai/pr-agent/pull/809' + ) command = "/review" # Command to run (e.g. 
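# --- Illustrative aside (not part of the patch) ---
# A minimal, self-contained sketch of the "drain remaining tasks" pattern used in
# run()/inner() above, so queued callback tasks get time to finish before returning.
# The coroutine names here are hypothetical.
import asyncio

async def _background(n: int) -> None:
    await asyncio.sleep(0.1)
    print(f"callback {n} finished")

async def _main() -> str:
    for i in range(3):
        asyncio.create_task(_background(i))  # fire-and-forget work, e.g. callbacks
    result = "done"
    # Wait for every task except the currently running coroutine itself.
    pending = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
    if pending:
        await asyncio.wait(pending)
    return result

if __name__ == "__main__":
    print(asyncio.run(_main()))
# --- end aside ---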
'/review', '/describe', '/ask="What is the purpose of this PR?"') # Setting the configurations diff --git a/apps/utils/pr_agent/config_loader.py b/apps/utils/pr_agent/config_loader.py index 9ae430c..8c1703b 100644 --- a/apps/utils/pr_agent/config_loader.py +++ b/apps/utils/pr_agent/config_loader.py @@ -11,26 +11,29 @@ current_dir = dirname(abspath(__file__)) global_settings = Dynaconf( envvar_prefix=False, merge_enabled=True, - settings_files=[join(current_dir, f) for f in [ - "settings/configuration.toml", - "settings/ignore.toml", - "settings/language_extensions.toml", - "settings/pr_reviewer_prompts.toml", - "settings/pr_questions_prompts.toml", - "settings/pr_line_questions_prompts.toml", - "settings/pr_description_prompts.toml", - "settings/pr_code_suggestions_prompts.toml", - "settings/pr_code_suggestions_reflect_prompts.toml", - "settings/pr_sort_code_suggestions_prompts.toml", - "settings/pr_information_from_user_prompts.toml", - "settings/pr_update_changelog_prompts.toml", - "settings/pr_custom_labels.toml", - "settings/pr_add_docs.toml", - "settings/custom_labels.toml", - "settings/pr_help_prompts.toml", - "settings/.secrets.toml", - "settings_prod/.secrets.toml", - ]] + settings_files=[ + join(current_dir, f) + for f in [ + "settings/configuration.toml", + "settings/ignore.toml", + "settings/language_extensions.toml", + "settings/pr_reviewer_prompts.toml", + "settings/pr_questions_prompts.toml", + "settings/pr_line_questions_prompts.toml", + "settings/pr_description_prompts.toml", + "settings/pr_code_suggestions_prompts.toml", + "settings/pr_code_suggestions_reflect_prompts.toml", + "settings/pr_sort_code_suggestions_prompts.toml", + "settings/pr_information_from_user_prompts.toml", + "settings/pr_update_changelog_prompts.toml", + "settings/pr_custom_labels.toml", + "settings/pr_add_docs.toml", + "settings/custom_labels.toml", + "settings/pr_help_prompts.toml", + "settings/.secrets.toml", + "settings_prod/.secrets.toml", + ] + ], ) diff --git a/apps/utils/pr_agent/git_providers/__init__.py b/apps/utils/pr_agent/git_providers/__init__.py index 1952deb..a4646c0 100644 --- a/apps/utils/pr_agent/git_providers/__init__.py +++ b/apps/utils/pr_agent/git_providers/__init__.py @@ -3,8 +3,9 @@ from starlette_context import context from utils.pr_agent.config_loader import get_settings from utils.pr_agent.git_providers.azuredevops_provider import AzureDevopsProvider from utils.pr_agent.git_providers.bitbucket_provider import BitbucketProvider -from utils.pr_agent.git_providers.bitbucket_server_provider import \ - BitbucketServerProvider +from utils.pr_agent.git_providers.bitbucket_server_provider import ( + BitbucketServerProvider, +) from utils.pr_agent.git_providers.codecommit_provider import CodeCommitProvider from utils.pr_agent.git_providers.gerrit_provider import GerritProvider from utils.pr_agent.git_providers.git_provider import GitProvider @@ -28,7 +29,9 @@ def get_git_provider(): try: provider_id = get_settings().config.git_provider except AttributeError as e: - raise ValueError("git_provider is a required attribute in the configuration file") from e + raise ValueError( + "git_provider is a required attribute in the configuration file" + ) from e if provider_id not in _GIT_PROVIDERS: raise ValueError(f"Unknown git provider: {provider_id}") return _GIT_PROVIDERS[provider_id] diff --git a/apps/utils/pr_agent/git_providers/azuredevops_provider.py b/apps/utils/pr_agent/git_providers/azuredevops_provider.py index cfe24b9..a2abbb4 100644 --- 
a/apps/utils/pr_agent/git_providers/azuredevops_provider.py +++ b/apps/utils/pr_agent/git_providers/azuredevops_provider.py @@ -6,25 +6,33 @@ from utils.pr_agent.algo.types import EDIT_TYPE, FilePatchInfo from ..algo.file_filter import filter_ignored from ..algo.language_handler import is_valid_file -from ..algo.utils import (PRDescriptionHeader, find_line_number_of_relevant_line_in_file, - load_large_diff) +from ..algo.utils import ( + PRDescriptionHeader, + find_line_number_of_relevant_line_in_file, + load_large_diff, +) from ..config_loader import get_settings from ..log import get_logger from .git_provider import GitProvider AZURE_DEVOPS_AVAILABLE = True ADO_APP_CLIENT_DEFAULT_ID = "499b84ac-1321-427f-aa17-267ca6975798/.default" -MAX_PR_DESCRIPTION_AZURE_LENGTH = 4000-1 +MAX_PR_DESCRIPTION_AZURE_LENGTH = 4000 - 1 try: # noinspection PyUnresolvedReferences # noinspection PyUnresolvedReferences from azure.devops.connection import Connection + # noinspection PyUnresolvedReferences - from azure.devops.v7_1.git.models import (Comment, CommentThread, - GitPullRequest, - GitPullRequestIterationChanges, - GitVersionDescriptor) + from azure.devops.v7_1.git.models import ( + Comment, + CommentThread, + GitPullRequest, + GitPullRequestIterationChanges, + GitVersionDescriptor, + ) + # noinspection PyUnresolvedReferences from azure.identity import DefaultAzureCredential from msrest.authentication import BasicAuthentication @@ -33,9 +41,8 @@ except ImportError: class AzureDevopsProvider(GitProvider): - def __init__( - self, pr_url: Optional[str] = None, incremental: Optional[bool] = False + self, pr_url: Optional[str] = None, incremental: Optional[bool] = False ): if not AZURE_DEVOPS_AVAILABLE: raise ImportError( @@ -67,13 +74,16 @@ class AzureDevopsProvider(GitProvider): if not relevant_lines_start or relevant_lines_start == -1: get_logger().warning( - f"Failed to publish code suggestion, relevant_lines_start is {relevant_lines_start}") + f"Failed to publish code suggestion, relevant_lines_start is {relevant_lines_start}" + ) continue if relevant_lines_end < relevant_lines_start: - get_logger().warning(f"Failed to publish code suggestion, " - f"relevant_lines_end is {relevant_lines_end} and " - f"relevant_lines_start is {relevant_lines_start}") + get_logger().warning( + f"Failed to publish code suggestion, " + f"relevant_lines_end is {relevant_lines_end} and " + f"relevant_lines_start is {relevant_lines_start}" + ) continue if relevant_lines_end > relevant_lines_start: @@ -98,30 +108,32 @@ class AzureDevopsProvider(GitProvider): for post_parameters in post_parameters_list: try: comment = Comment(content=post_parameters["body"], comment_type=1) - thread = CommentThread(comments=[comment], - thread_context={ - "filePath": post_parameters["path"], - "rightFileStart": { - "line": post_parameters["start_line"], - "offset": 1, - }, - "rightFileEnd": { - "line": post_parameters["line"], - "offset": 1, - }, - }) + thread = CommentThread( + comments=[comment], + thread_context={ + "filePath": post_parameters["path"], + "rightFileStart": { + "line": post_parameters["start_line"], + "offset": 1, + }, + "rightFileEnd": { + "line": post_parameters["line"], + "offset": 1, + }, + }, + ) self.azure_devops_client.create_thread( comment_thread=thread, project=self.workspace_slug, repository_id=self.repo_slug, - pull_request_id=self.pr_num + pull_request_id=self.pr_num, ) except Exception as e: - get_logger().warning(f"Azure failed to publish code suggestion, error: {e}") + get_logger().warning( + f"Azure failed to 
publish code suggestion, error: {e}" + ) return True - - def get_pr_description_full(self) -> str: return self.pr.description @@ -204,9 +216,9 @@ class AzureDevopsProvider(GitProvider): def get_files(self): files = [] for i in self.azure_devops_client.get_pull_request_commits( - project=self.workspace_slug, - repository_id=self.repo_slug, - pull_request_id=self.pr_num, + project=self.workspace_slug, + repository_id=self.repo_slug, + pull_request_id=self.pr_num, ): changes_obj = self.azure_devops_client.get_changes( project=self.workspace_slug, @@ -220,7 +232,6 @@ class AzureDevopsProvider(GitProvider): def get_diff_files(self) -> list[FilePatchInfo]: try: - if self.diff_files: return self.diff_files @@ -231,18 +242,20 @@ class AzureDevopsProvider(GitProvider): iterations = self.azure_devops_client.get_pull_request_iterations( repository_id=self.repo_slug, pull_request_id=self.pr_num, - project=self.workspace_slug + project=self.workspace_slug, ) changes = None if iterations: - iteration_id = iterations[-1].id # Get the last iteration (most recent changes) + iteration_id = iterations[ + -1 + ].id # Get the last iteration (most recent changes) # Get changes for the iteration changes = self.azure_devops_client.get_pull_request_iteration_changes( repository_id=self.repo_slug, pull_request_id=self.pr_num, iteration_id=iteration_id, - project=self.workspace_slug + project=self.workspace_slug, ) diff_files = [] diffs = [] @@ -253,7 +266,9 @@ class AzureDevopsProvider(GitProvider): path = item.get('path', None) if path: diffs.append(path) - diff_types[path] = change.additional_properties.get('changeType', 'Unknown') + diff_types[path] = change.additional_properties.get( + 'changeType', 'Unknown' + ) # wrong implementation - gets all the files that were changed in any commit in the PR # commits = self.azure_devops_client.get_pull_request_commits( @@ -284,9 +299,13 @@ class AzureDevopsProvider(GitProvider): diffs = filter_ignored(diffs_original, 'azure') if diffs_original != diffs: try: - get_logger().info(f"Filtered out [ignore] files for pull request:", extra= - {"files": diffs_original, # diffs is just a list of names - "filtered_files": diffs}) + get_logger().info( + f"Filtered out [ignore] files for pull request:", + extra={ + "files": diffs_original, # diffs is just a list of names + "filtered_files": diffs, + }, + ) except Exception: pass @@ -311,7 +330,10 @@ class AzureDevopsProvider(GitProvider): new_file_content_str = new_file_content_str.content except Exception as error: - get_logger().error(f"Failed to retrieve new file content of {file} at version {version}", error=error) + get_logger().error( + f"Failed to retrieve new file content of {file} at version {version}", + error=error, + ) # get_logger().error( # "Failed to retrieve new file content of %s at version %s. 
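# --- Illustrative aside (not part of the patch) ---
# Sketch of how the Azure DevOps changeType string maps onto the edit type used by
# get_diff_files() above; EditType is a local stand-in for pr_agent's EDIT_TYPE.
from enum import Enum

class EditType(Enum):
    ADDED = "added"
    DELETED = "deleted"
    RENAMED = "renamed"
    MODIFIED = "modified"

def classify_change(change_type: str) -> EditType:
    if change_type == "add":
        return EditType.ADDED
    if change_type == "delete":
        return EditType.DELETED
    if "rename" in change_type:  # changeType can be "rename" or "edit, rename"
        return EditType.RENAMED
    return EditType.MODIFIED

assert classify_change("edit, rename") is EditType.RENAMED
# --- end aside ---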
Error: %s", # file, @@ -325,7 +347,9 @@ class AzureDevopsProvider(GitProvider): edit_type = EDIT_TYPE.ADDED elif diff_types[file] == "delete": edit_type = EDIT_TYPE.DELETED - elif "rename" in diff_types[file]: # diff_type can be `rename` | `edit, rename` + elif ( + "rename" in diff_types[file] + ): # diff_type can be `rename` | `edit, rename` edit_type = EDIT_TYPE.RENAMED version = GitVersionDescriptor( @@ -345,17 +369,27 @@ class AzureDevopsProvider(GitProvider): ) original_file_content_str = original_file_content_str.content except Exception as error: - get_logger().error(f"Failed to retrieve original file content of {file} at version {version}", error=error) + get_logger().error( + f"Failed to retrieve original file content of {file} at version {version}", + error=error, + ) original_file_content_str = "" patch = load_large_diff( - file, new_file_content_str, original_file_content_str, show_warning=False + file, + new_file_content_str, + original_file_content_str, + show_warning=False, ).rstrip() # count number of lines added and removed patch_lines = patch.splitlines(keepends=True) - num_plus_lines = len([line for line in patch_lines if line.startswith('+')]) - num_minus_lines = len([line for line in patch_lines if line.startswith('-')]) + num_plus_lines = len( + [line for line in patch_lines if line.startswith('+')] + ) + num_minus_lines = len( + [line for line in patch_lines if line.startswith('-')] + ) diff_files.append( FilePatchInfo( @@ -376,27 +410,35 @@ class AzureDevopsProvider(GitProvider): get_logger().exception(f"Failed to get diff files, error: {e}") return [] - def publish_comment(self, pr_comment: str, is_temporary: bool = False, thread_context=None): + def publish_comment( + self, pr_comment: str, is_temporary: bool = False, thread_context=None + ): if is_temporary and not get_settings().config.publish_output_progress: - get_logger().debug(f"Skipping publish_comment for temporary comment: {pr_comment}") + get_logger().debug( + f"Skipping publish_comment for temporary comment: {pr_comment}" + ) return None comment = Comment(content=pr_comment) - thread = CommentThread(comments=[comment], thread_context=thread_context, status=5) + thread = CommentThread( + comments=[comment], thread_context=thread_context, status=5 + ) thread_response = self.azure_devops_client.create_thread( comment_thread=thread, project=self.workspace_slug, repository_id=self.repo_slug, pull_request_id=self.pr_num, ) - response = {"thread_id": thread_response.id, "comment_id": thread_response.comments[0].id} + response = { + "thread_id": thread_response.id, + "comment_id": thread_response.comments[0].id, + } if is_temporary: self.temp_comments.append(response) return response def publish_description(self, pr_title: str, pr_body: str): if len(pr_body) > MAX_PR_DESCRIPTION_AZURE_LENGTH: - - usage_guide_text='
<details> <summary> <strong>✨ Describe tool usage guide:</strong></summary><hr>
' + usage_guide_text = '
<details> <summary> <strong>✨ Describe tool usage guide:</strong></summary><hr>
' ind = pr_body.find(usage_guide_text) if ind != -1: pr_body = pr_body[:ind] @@ -409,7 +451,10 @@ class AzureDevopsProvider(GitProvider): if len(pr_body) > MAX_PR_DESCRIPTION_AZURE_LENGTH: trunction_message = " ... (description truncated due to length limit)" - pr_body = pr_body[:MAX_PR_DESCRIPTION_AZURE_LENGTH - len(trunction_message)] + trunction_message + pr_body = ( + pr_body[: MAX_PR_DESCRIPTION_AZURE_LENGTH - len(trunction_message)] + + trunction_message + ) get_logger().warning("PR description was truncated due to length limit") try: updated_pr = GitPullRequest() @@ -433,50 +478,79 @@ class AzureDevopsProvider(GitProvider): except Exception as e: get_logger().exception(f"Failed to remove temp comments, error: {e}") - def publish_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str, original_suggestion=None): - self.publish_inline_comments([self.create_inline_comment(body, relevant_file, relevant_line_in_file)]) + def publish_inline_comment( + self, + body: str, + relevant_file: str, + relevant_line_in_file: str, + original_suggestion=None, + ): + self.publish_inline_comments( + [self.create_inline_comment(body, relevant_file, relevant_line_in_file)] + ) - - def create_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str, - absolute_position: int = None): - position, absolute_position = find_line_number_of_relevant_line_in_file(self.get_diff_files(), - relevant_file.strip('`'), - relevant_line_in_file, - absolute_position) + def create_inline_comment( + self, + body: str, + relevant_file: str, + relevant_line_in_file: str, + absolute_position: int = None, + ): + position, absolute_position = find_line_number_of_relevant_line_in_file( + self.get_diff_files(), + relevant_file.strip('`'), + relevant_line_in_file, + absolute_position, + ) if position == -1: if get_settings().config.verbosity_level >= 2: - get_logger().info(f"Could not find position for {relevant_file} {relevant_line_in_file}") + get_logger().info( + f"Could not find position for {relevant_file} {relevant_line_in_file}" + ) subject_type = "FILE" else: subject_type = "LINE" path = relevant_file.strip() - return dict(body=body, path=path, position=position, absolute_position=absolute_position) if subject_type == "LINE" else {} + return ( + dict( + body=body, + path=path, + position=position, + absolute_position=absolute_position, + ) + if subject_type == "LINE" + else {} + ) - def publish_inline_comments(self, comments: list[dict], disable_fallback: bool = False): - overall_success = True - for comment in comments: - try: - self.publish_comment(comment["body"], - thread_context={ - "filePath": comment["path"], - "rightFileStart": { - "line": comment["absolute_position"], - "offset": comment["position"], - }, - "rightFileEnd": { - "line": comment["absolute_position"], - "offset": comment["position"], - }, - }) - if get_settings().config.verbosity_level >= 2: - get_logger().info( - f"Published code suggestion on {self.pr_num} at {comment['path']}" - ) - except Exception as e: - if get_settings().config.verbosity_level >= 2: - get_logger().error(f"Failed to publish code suggestion, error: {e}") - overall_success = False - return overall_success + def publish_inline_comments( + self, comments: list[dict], disable_fallback: bool = False + ): + overall_success = True + for comment in comments: + try: + self.publish_comment( + comment["body"], + thread_context={ + "filePath": comment["path"], + "rightFileStart": { + "line": comment["absolute_position"], + "offset": 
comment["position"], + }, + "rightFileEnd": { + "line": comment["absolute_position"], + "offset": comment["position"], + }, + }, + ) + if get_settings().config.verbosity_level >= 2: + get_logger().info( + f"Published code suggestion on {self.pr_num} at {comment['path']}" + ) + except Exception as e: + if get_settings().config.verbosity_level >= 2: + get_logger().error(f"Failed to publish code suggestion, error: {e}") + overall_success = False + return overall_success def get_title(self): return self.pr.title @@ -521,7 +595,11 @@ class AzureDevopsProvider(GitProvider): return 0 def get_issue_comments(self): - threads = self.azure_devops_client.get_threads(repository_id=self.repo_slug, pull_request_id=self.pr_num, project=self.workspace_slug) + threads = self.azure_devops_client.get_threads( + repository_id=self.repo_slug, + pull_request_id=self.pr_num, + project=self.workspace_slug, + ) threads.reverse() comment_list = [] for thread in threads: @@ -532,7 +610,9 @@ class AzureDevopsProvider(GitProvider): comment_list.append(comment) return comment_list - def add_eyes_reaction(self, issue_comment_id: int, disable_eyes: bool = False) -> Optional[int]: + def add_eyes_reaction( + self, issue_comment_id: int, disable_eyes: bool = False + ) -> Optional[int]: return True def remove_reaction(self, issue_comment_id: int, reaction_id: int) -> bool: @@ -547,16 +627,22 @@ class AzureDevopsProvider(GitProvider): raise ValueError( "The provided URL does not appear to be a Azure DevOps PR URL" ) - if len(path_parts) == 6: # "https://dev.azure.com/organization/project/_git/repo/pullrequest/1" + if ( + len(path_parts) == 6 + ): # "https://dev.azure.com/organization/project/_git/repo/pullrequest/1" workspace_slug = path_parts[1] repo_slug = path_parts[3] pr_number = int(path_parts[5]) - elif len(path_parts) == 5: # 'https://organization.visualstudio.com/project/_git/repo/pullrequest/1' + elif ( + len(path_parts) == 5 + ): # 'https://organization.visualstudio.com/project/_git/repo/pullrequest/1' workspace_slug = path_parts[0] repo_slug = path_parts[2] pr_number = int(path_parts[4]) else: - raise ValueError("The provided URL does not appear to be a Azure DevOps PR URL") + raise ValueError( + "The provided URL does not appear to be a Azure DevOps PR URL" + ) return workspace_slug, repo_slug, pr_number @@ -575,12 +661,16 @@ class AzureDevopsProvider(GitProvider): # try to use azure default credentials # see https://learn.microsoft.com/en-us/python/api/overview/azure/identity-readme?view=azure-python # for usage and env var configuration of user-assigned managed identity, local machine auth etc. - get_logger().info("No PAT found in settings, trying to use Azure Default Credentials.") + get_logger().info( + "No PAT found in settings, trying to use Azure Default Credentials." 
+ ) credentials = DefaultAzureCredential() accessToken = credentials.get_token(ADO_APP_CLIENT_DEFAULT_ID) auth_token = accessToken.token except Exception as e: - get_logger().error(f"No PAT found in settings, and Azure Default Authentication failed, error: {e}") + get_logger().error( + f"No PAT found in settings, and Azure Default Authentication failed, error: {e}" + ) raise credentials = BasicAuthentication("", auth_token) diff --git a/apps/utils/pr_agent/git_providers/bitbucket_provider.py b/apps/utils/pr_agent/git_providers/bitbucket_provider.py index deae293..53d26b3 100644 --- a/apps/utils/pr_agent/git_providers/bitbucket_provider.py +++ b/apps/utils/pr_agent/git_providers/bitbucket_provider.py @@ -52,13 +52,19 @@ class BitbucketProvider(GitProvider): self.git_files = None if pr_url: self.set_pr(pr_url) - self.bitbucket_comment_api_url = self.pr._BitbucketBase__data["links"]["comments"]["href"] - self.bitbucket_pull_request_api_url = self.pr._BitbucketBase__data["links"]['self']['href'] + self.bitbucket_comment_api_url = self.pr._BitbucketBase__data["links"][ + "comments" + ]["href"] + self.bitbucket_pull_request_api_url = self.pr._BitbucketBase__data["links"][ + 'self' + ]['href'] def get_repo_settings(self): try: - url = (f"https://api.bitbucket.org/2.0/repositories/{self.workspace_slug}/{self.repo_slug}/src/" - f"{self.pr.destination_branch}/.pr_agent.toml") + url = ( + f"https://api.bitbucket.org/2.0/repositories/{self.workspace_slug}/{self.repo_slug}/src/" + f"{self.pr.destination_branch}/.pr_agent.toml" + ) response = requests.request("GET", url, headers=self.headers) if response.status_code == 404: # not found return "" @@ -74,20 +80,27 @@ class BitbucketProvider(GitProvider): post_parameters_list = [] for suggestion in code_suggestions: body = suggestion["body"] - original_suggestion = suggestion.get('original_suggestion', None) # needed for diff code + original_suggestion = suggestion.get( + 'original_suggestion', None + ) # needed for diff code if original_suggestion: try: existing_code = original_suggestion['existing_code'].rstrip() + "\n" improved_code = original_suggestion['improved_code'].rstrip() + "\n" - diff = difflib.unified_diff(existing_code.split('\n'), - improved_code.split('\n'), n=999) + diff = difflib.unified_diff( + existing_code.split('\n'), improved_code.split('\n'), n=999 + ) patch_orig = "\n".join(diff) patch = "\n".join(patch_orig.splitlines()[5:]).strip('\n') diff_code = f"\n\n```diff\n{patch.rstrip()}\n```" # replace ```suggestion ... 
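# --- Illustrative aside (not part of the patch) ---
# Standalone sketch of the technique above: turning an existing/improved snippet
# into a unified-diff body with difflib, then wrapping it as a ```diff``` block.
import difflib

existing_code = "def add(a, b):\n    return a - b\n"
improved_code = "def add(a, b):\n    return a + b\n"

diff = difflib.unified_diff(existing_code.split("\n"), improved_code.split("\n"), n=999)
patch_orig = "\n".join(diff)
# Drop the synthetic ---/+++/@@ header lines, keeping only the hunk body.
patch = "\n".join(patch_orig.splitlines()[5:]).strip("\n")
print(f"```diff\n{patch.rstrip()}\n```")
# --- end aside ---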
``` with diff_code, using regex: - body = re.sub(r'```suggestion.*?```', diff_code, body, flags=re.DOTALL) + body = re.sub( + r'```suggestion.*?```', diff_code, body, flags=re.DOTALL + ) except Exception as e: - get_logger().exception(f"Bitbucket failed to get diff code for publishing, error: {e}") + get_logger().exception( + f"Bitbucket failed to get diff code for publishing, error: {e}" + ) continue relevant_file = suggestion["relevant_file"] @@ -129,15 +142,22 @@ class BitbucketProvider(GitProvider): self.publish_inline_comments(post_parameters_list) return True except Exception as e: - get_logger().error(f"Bitbucket failed to publish code suggestion, error: {e}") + get_logger().error( + f"Bitbucket failed to publish code suggestion, error: {e}" + ) return False def publish_file_comments(self, file_comments: list) -> bool: pass def is_supported(self, capability: str) -> bool: - if capability in ['get_issue_comments', 'publish_inline_comments', 'get_labels', 'gfm_markdown', - 'publish_file_comments']: + if capability in [ + 'get_issue_comments', + 'publish_inline_comments', + 'get_labels', + 'gfm_markdown', + 'publish_file_comments', + ]: return False return True @@ -169,12 +189,14 @@ class BitbucketProvider(GitProvider): names_original = [d.new.path for d in diffs_original] names_kept = [d.new.path for d in diffs] names_filtered = list(set(names_original) - set(names_kept)) - get_logger().info(f"Filtered out [ignore] files for PR", extra={ - 'original_files': names_original, - 'names_kept': names_kept, - 'names_filtered': names_filtered - - }) + get_logger().info( + f"Filtered out [ignore] files for PR", + extra={ + 'original_files': names_original, + 'names_kept': names_kept, + 'names_filtered': names_filtered, + }, + ) except Exception as e: pass @@ -189,20 +211,32 @@ class BitbucketProvider(GitProvider): for encoding in encodings_to_try: try: pr_patches = self.pr.diff(encoding=encoding) - get_logger().info(f"Successfully decoded PR patch with encoding {encoding}") + get_logger().info( + f"Successfully decoded PR patch with encoding {encoding}" + ) break except UnicodeDecodeError: continue if pr_patches is None: - raise ValueError(f"Failed to decode PR patch with encodings {encodings_to_try}") + raise ValueError( + f"Failed to decode PR patch with encodings {encodings_to_try}" + ) - diff_split = ["diff --git" + x for x in pr_patches.split("diff --git") if x.strip()] + diff_split = [ + "diff --git" + x for x in pr_patches.split("diff --git") if x.strip() + ] # filter all elements of 'diff_split' that are of indices in 'diffs_original' that are not in 'diffs' if len(diff_split) > len(diffs) and len(diffs_original) == len(diff_split): - diff_split = [diff_split[i] for i in range(len(diff_split)) if diffs_original[i] in diffs] + diff_split = [ + diff_split[i] + for i in range(len(diff_split)) + if diffs_original[i] in diffs + ] if len(diff_split) != len(diffs): - get_logger().error(f"Error - failed to split the diff into {len(diffs)} parts") + get_logger().error( + f"Error - failed to split the diff into {len(diffs)} parts" + ) return [] # bitbucket diff has a header for each file, we need to remove it: # "diff --git filename @@ -213,22 +247,34 @@ class BitbucketProvider(GitProvider): # @@ -... 
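# --- Illustrative aside (not part of the patch) ---
# Sketch of the per-file split performed above: a combined Bitbucket patch is cut
# on the "diff --git" marker before the extra per-file header lines are removed.
def split_patch(pr_patch: str) -> list:
    return ["diff --git" + part for part in pr_patch.split("diff --git") if part.strip()]

_example = (
    "diff --git a/a.py b/a.py\n@@ -1 +1 @@\n-x = 1\n+x = 2\n"
    "diff --git a/b.py b/b.py\n@@ -1 +1 @@\n-y = 1\n+y = 2\n"
)
assert len(split_patch(_example)) == 2
assert all(chunk.startswith("diff --git") for chunk in split_patch(_example))
# --- end aside ---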
@@" for i, _ in enumerate(diff_split): diff_split_lines = diff_split[i].splitlines() - if (len(diff_split_lines) >= 6) and \ - ((diff_split_lines[2].startswith("---") and - diff_split_lines[3].startswith("+++") and - diff_split_lines[4].startswith("@@")) or - (diff_split_lines[3].startswith("---") and # new or deleted file - diff_split_lines[4].startswith("+++") and - diff_split_lines[5].startswith("@@"))): + if (len(diff_split_lines) >= 6) and ( + ( + diff_split_lines[2].startswith("---") + and diff_split_lines[3].startswith("+++") + and diff_split_lines[4].startswith("@@") + ) + or ( + diff_split_lines[3].startswith("---") + and diff_split_lines[4].startswith("+++") # new or deleted file + and diff_split_lines[5].startswith("@@") + ) + ): diff_split[i] = "\n".join(diff_split_lines[4:]) else: - if diffs[i].data.get('lines_added', 0) == 0 and diffs[i].data.get('lines_removed', 0) == 0: + if ( + diffs[i].data.get('lines_added', 0) == 0 + and diffs[i].data.get('lines_removed', 0) == 0 + ): diff_split[i] = "" elif len(diff_split_lines) <= 3: diff_split[i] = "" - get_logger().info(f"Disregarding empty diff for file {_gef_filename(diffs[i])}") + get_logger().info( + f"Disregarding empty diff for file {_gef_filename(diffs[i])}" + ) else: - get_logger().warning(f"Bitbucket failed to get diff for file {_gef_filename(diffs[i])}") + get_logger().warning( + f"Bitbucket failed to get diff for file {_gef_filename(diffs[i])}" + ) diff_split[i] = "" invalid_files_names = [] @@ -246,24 +292,32 @@ class BitbucketProvider(GitProvider): if get_settings().get("bitbucket_app.avoid_full_files", False): original_file_content_str = "" new_file_content_str = "" - elif counter_valid < MAX_FILES_ALLOWED_FULL // 2: # factor 2 because bitbucket has limited API calls + elif ( + counter_valid < MAX_FILES_ALLOWED_FULL // 2 + ): # factor 2 because bitbucket has limited API calls if diff.old.get_data("links"): original_file_content_str = self._get_pr_file_content( - diff.old.get_data("links")['self']['href']) + diff.old.get_data("links")['self']['href'] + ) else: original_file_content_str = "" if diff.new.get_data("links"): - new_file_content_str = self._get_pr_file_content(diff.new.get_data("links")['self']['href']) + new_file_content_str = self._get_pr_file_content( + diff.new.get_data("links")['self']['href'] + ) else: new_file_content_str = "" else: if counter_valid == MAX_FILES_ALLOWED_FULL // 2: get_logger().info( - f"Bitbucket too many files in PR, will avoid loading full content for rest of files") + f"Bitbucket too many files in PR, will avoid loading full content for rest of files" + ) original_file_content_str = "" new_file_content_str = "" except Exception as e: - get_logger().exception(f"Error - bitbucket failed to get file content, error: {e}") + get_logger().exception( + f"Error - bitbucket failed to get file content, error: {e}" + ) original_file_content_str = "" new_file_content_str = "" @@ -285,7 +339,9 @@ class BitbucketProvider(GitProvider): diff_files.append(file_patch_canonic_structure) if invalid_files_names: - get_logger().info(f"Disregarding files with invalid extensions:\n{invalid_files_names}") + get_logger().info( + f"Disregarding files with invalid extensions:\n{invalid_files_names}" + ) self.diff_files = diff_files return diff_files @@ -296,11 +352,14 @@ class BitbucketProvider(GitProvider): def get_comment_url(self, comment): return comment.data['links']['html']['href'] - def publish_persistent_comment(self, pr_comment: str, - initial_header: str, - update_header: bool = True, - 
name='review', - final_update_message=True): + def publish_persistent_comment( + self, + pr_comment: str, + initial_header: str, + update_header: bool = True, + name='review', + final_update_message=True, + ): try: for comment in self.pr.comments(): body = comment.raw @@ -309,15 +368,20 @@ class BitbucketProvider(GitProvider): comment_url = self.get_comment_url(comment) if update_header: updated_header = f"{initial_header}\n\n#### ({name.capitalize()} updated until commit {latest_commit_url})\n" - pr_comment_updated = pr_comment.replace(initial_header, updated_header) + pr_comment_updated = pr_comment.replace( + initial_header, updated_header + ) else: pr_comment_updated = pr_comment - get_logger().info(f"Persistent mode - updating comment {comment_url} to latest {name} message") + get_logger().info( + f"Persistent mode - updating comment {comment_url} to latest {name} message" + ) d = {"content": {"raw": pr_comment_updated}} response = comment._update_data(comment.put(None, data=d)) if final_update_message: self.publish_comment( - f"**[Persistent {name}]({comment_url})** updated to latest commit {latest_commit_url}") + f"**[Persistent {name}]({comment_url})** updated to latest commit {latest_commit_url}" + ) return except Exception as e: get_logger().exception(f"Failed to update persistent review, error: {e}") @@ -326,7 +390,9 @@ class BitbucketProvider(GitProvider): def publish_comment(self, pr_comment: str, is_temporary: bool = False): if is_temporary and not get_settings().config.publish_output_progress: - get_logger().debug(f"Skipping publish_comment for temporary comment: {pr_comment}") + get_logger().debug( + f"Skipping publish_comment for temporary comment: {pr_comment}" + ) return None pr_comment = self.limit_output_characters(pr_comment, self.max_comment_length) comment = self.pr.comment(pr_comment) @@ -355,39 +421,58 @@ class BitbucketProvider(GitProvider): get_logger().exception(f"Failed to remove comment, error: {e}") # function to create_inline_comment - def create_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str, - absolute_position: int = None): + def create_inline_comment( + self, + body: str, + relevant_file: str, + relevant_line_in_file: str, + absolute_position: int = None, + ): body = self.limit_output_characters(body, self.max_comment_length) - position, absolute_position = find_line_number_of_relevant_line_in_file(self.get_diff_files(), - relevant_file.strip('`'), - relevant_line_in_file, - absolute_position) + position, absolute_position = find_line_number_of_relevant_line_in_file( + self.get_diff_files(), + relevant_file.strip('`'), + relevant_line_in_file, + absolute_position, + ) if position == -1: if get_settings().config.verbosity_level >= 2: - get_logger().info(f"Could not find position for {relevant_file} {relevant_line_in_file}") + get_logger().info( + f"Could not find position for {relevant_file} {relevant_line_in_file}" + ) subject_type = "FILE" else: subject_type = "LINE" path = relevant_file.strip() - return dict(body=body, path=path, position=absolute_position) if subject_type == "LINE" else {} + return ( + dict(body=body, path=path, position=absolute_position) + if subject_type == "LINE" + else {} + ) - def publish_inline_comment(self, comment: str, from_line: int, file: str, original_suggestion=None): + def publish_inline_comment( + self, comment: str, from_line: int, file: str, original_suggestion=None + ): comment = self.limit_output_characters(comment, self.max_comment_length) - payload = json.dumps({ - "content": { - 
"raw": comment, - }, - "inline": { - "to": from_line, - "path": file - }, - }) + payload = json.dumps( + { + "content": { + "raw": comment, + }, + "inline": {"to": from_line, "path": file}, + } + ) response = requests.request( "POST", self.bitbucket_comment_api_url, data=payload, headers=self.headers ) return response - def get_line_link(self, relevant_file: str, relevant_line_start: int, relevant_line_end: int = None) -> str: + def get_line_link( + self, + relevant_file: str, + relevant_line_start: int, + relevant_line_end: int = None, + ) -> str: if relevant_line_start == -1: link = f"{self.pr_url}/#L{relevant_file}" else: @@ -402,8 +487,9 @@ class BitbucketProvider(GitProvider): return "" diff_files = self.get_diff_files() - position, absolute_position = find_line_number_of_relevant_line_in_file \ - (diff_files, relevant_file, relevant_line_str) + position, absolute_position = find_line_number_of_relevant_line_in_file( + diff_files, relevant_file, relevant_line_str + ) if absolute_position != -1 and self.pr_url: link = f"{self.pr_url}/#L{relevant_file}T{absolute_position}" @@ -417,12 +503,18 @@ class BitbucketProvider(GitProvider): def publish_inline_comments(self, comments: list[dict]): for comment in comments: if 'position' in comment: - self.publish_inline_comment(comment['body'], comment['position'], comment['path']) + self.publish_inline_comment( + comment['body'], comment['position'], comment['path'] + ) elif 'start_line' in comment: # multi-line comment # note that bitbucket does not seem to support range - only a comment on a single line - https://community.developer.atlassian.com/t/api-post-endpoint-for-inline-pull-request-comments/60452 - self.publish_inline_comment(comment['body'], comment['start_line'], comment['path']) + self.publish_inline_comment( + comment['body'], comment['start_line'], comment['path'] + ) elif 'line' in comment: # single-line comment - self.publish_inline_comment(comment['body'], comment['line'], comment['path']) + self.publish_inline_comment( + comment['body'], comment['line'], comment['path'] + ) else: get_logger().error(f"Could not publish inline comment {comment}") @@ -450,7 +542,9 @@ class BitbucketProvider(GitProvider): "Bitbucket provider does not support issue comments yet" ) - def add_eyes_reaction(self, issue_comment_id: int, disable_eyes: bool = False) -> Optional[int]: + def add_eyes_reaction( + self, issue_comment_id: int, disable_eyes: bool = False + ) -> Optional[int]: return True def remove_reaction(self, issue_comment_id: int, reaction_id: int) -> bool: @@ -495,8 +589,10 @@ class BitbucketProvider(GitProvider): branch = self.pr.data["source"]["commit"]["hash"] elif branch == self.pr.destination_branch: branch = self.pr.data["destination"]["commit"]["hash"] - url = (f"https://api.bitbucket.org/2.0/repositories/{self.workspace_slug}/{self.repo_slug}/src/" - f"{branch}/{file_path}") + url = ( + f"https://api.bitbucket.org/2.0/repositories/{self.workspace_slug}/{self.repo_slug}/src/" + f"{branch}/{file_path}" + ) response = requests.request("GET", url, headers=self.headers) if response.status_code == 404: # not found return "" @@ -505,23 +601,28 @@ class BitbucketProvider(GitProvider): except Exception: return "" - def create_or_update_pr_file(self, file_path: str, branch: str, contents="", message="") -> None: - url = (f"https://api.bitbucket.org/2.0/repositories/{self.workspace_slug}/{self.repo_slug}/src/") + def create_or_update_pr_file( + self, file_path: str, branch: str, contents="", message="" + ) -> None: + url = 
f"https://api.bitbucket.org/2.0/repositories/{self.workspace_slug}/{self.repo_slug}/src/" if not message: if contents: message = f"Update {file_path}" else: message = f"Create {file_path}" files = {file_path: contents} - data = { - "message": message, - "branch": branch - } - headers = {'Authorization': self.headers['Authorization']} if 'Authorization' in self.headers else {} + data = {"message": message, "branch": branch} + headers = ( + {'Authorization': self.headers['Authorization']} + if 'Authorization' in self.headers + else {} + ) try: requests.request("POST", url, headers=headers, data=data, files=files) except Exception: - get_logger().exception(f"Failed to create empty file {file_path} in branch {branch}") + get_logger().exception( + f"Failed to create empty file {file_path} in branch {branch}" + ) def _get_pr_file_content(self, remote_link: str): try: @@ -538,16 +639,19 @@ class BitbucketProvider(GitProvider): # bitbucket does not support labels def publish_description(self, pr_title: str, description: str): - payload = json.dumps({ - "description": description, - "title": pr_title + payload = json.dumps({"description": description, "title": pr_title}) - }) - - response = requests.request("PUT", self.bitbucket_pull_request_api_url, headers=self.headers, data=payload) + response = requests.request( + "PUT", + self.bitbucket_pull_request_api_url, + headers=self.headers, + data=payload, + ) try: if response.status_code != 200: - get_logger().info(f"Failed to update description, error code: {response.status_code}") + get_logger().info( + f"Failed to update description, error code: {response.status_code}" + ) except: pass return response diff --git a/apps/utils/pr_agent/git_providers/bitbucket_server_provider.py b/apps/utils/pr_agent/git_providers/bitbucket_server_provider.py index 22f85e5..149bdad 100644 --- a/apps/utils/pr_agent/git_providers/bitbucket_server_provider.py +++ b/apps/utils/pr_agent/git_providers/bitbucket_server_provider.py @@ -11,8 +11,7 @@ from requests.exceptions import HTTPError from ..algo.git_patch_processing import decode_if_bytes from ..algo.language_handler import is_valid_file from ..algo.types import EDIT_TYPE, FilePatchInfo -from ..algo.utils import (find_line_number_of_relevant_line_in_file, - load_large_diff) +from ..algo.utils import find_line_number_of_relevant_line_in_file, load_large_diff from ..config_loader import get_settings from ..log import get_logger from .git_provider import GitProvider @@ -20,8 +19,10 @@ from .git_provider import GitProvider class BitbucketServerProvider(GitProvider): def __init__( - self, pr_url: Optional[str] = None, incremental: Optional[bool] = False, - bitbucket_client: Optional[Bitbucket] = None, + self, + pr_url: Optional[str] = None, + incremental: Optional[bool] = False, + bitbucket_client: Optional[Bitbucket] = None, ): self.bitbucket_server_url = None self.workspace_slug = None @@ -36,11 +37,16 @@ class BitbucketServerProvider(GitProvider): self.bitbucket_pull_request_api_url = pr_url self.bitbucket_server_url = self._parse_bitbucket_server(url=pr_url) - self.bitbucket_client = bitbucket_client or Bitbucket(url=self.bitbucket_server_url, - token=get_settings().get("BITBUCKET_SERVER.BEARER_TOKEN", - None)) + self.bitbucket_client = bitbucket_client or Bitbucket( + url=self.bitbucket_server_url, + token=get_settings().get("BITBUCKET_SERVER.BEARER_TOKEN", None), + ) try: - self.bitbucket_api_version = parse_version(self.bitbucket_client.get("rest/api/1.0/application-properties").get('version')) + 
self.bitbucket_api_version = parse_version( + self.bitbucket_client.get("rest/api/1.0/application-properties").get( + 'version' + ) + ) except Exception: self.bitbucket_api_version = None @@ -49,7 +55,12 @@ class BitbucketServerProvider(GitProvider): def get_repo_settings(self): try: - content = self.bitbucket_client.get_content_of_file(self.workspace_slug, self.repo_slug, ".pr_agent.toml", self.get_pr_branch()) + content = self.bitbucket_client.get_content_of_file( + self.workspace_slug, + self.repo_slug, + ".pr_agent.toml", + self.get_pr_branch(), + ) return content except Exception as e: @@ -70,20 +81,27 @@ class BitbucketServerProvider(GitProvider): post_parameters_list = [] for suggestion in code_suggestions: body = suggestion["body"] - original_suggestion = suggestion.get('original_suggestion', None) # needed for diff code + original_suggestion = suggestion.get( + 'original_suggestion', None + ) # needed for diff code if original_suggestion: try: existing_code = original_suggestion['existing_code'].rstrip() + "\n" improved_code = original_suggestion['improved_code'].rstrip() + "\n" - diff = difflib.unified_diff(existing_code.split('\n'), - improved_code.split('\n'), n=999) + diff = difflib.unified_diff( + existing_code.split('\n'), improved_code.split('\n'), n=999 + ) patch_orig = "\n".join(diff) patch = "\n".join(patch_orig.splitlines()[5:]).strip('\n') diff_code = f"\n\n```diff\n{patch.rstrip()}\n```" # replace ```suggestion ... ``` with diff_code, using regex: - body = re.sub(r'```suggestion.*?```', diff_code, body, flags=re.DOTALL) + body = re.sub( + r'```suggestion.*?```', diff_code, body, flags=re.DOTALL + ) except Exception as e: - get_logger().exception(f"Bitbucket failed to get diff code for publishing, error: {e}") + get_logger().exception( + f"Bitbucket failed to get diff code for publishing, error: {e}" + ) continue relevant_file = suggestion["relevant_file"] relevant_lines_start = suggestion["relevant_lines_start"] @@ -134,7 +152,12 @@ class BitbucketServerProvider(GitProvider): pass def is_supported(self, capability: str) -> bool: - if capability in ['get_issue_comments', 'get_labels', 'gfm_markdown', 'publish_file_comments']: + if capability in [ + 'get_issue_comments', + 'get_labels', + 'gfm_markdown', + 'publish_file_comments', + ]: return False return True @@ -145,23 +168,28 @@ class BitbucketServerProvider(GitProvider): def get_file(self, path: str, commit_id: str): file_content = "" try: - file_content = self.bitbucket_client.get_content_of_file(self.workspace_slug, - self.repo_slug, - path, - commit_id) + file_content = self.bitbucket_client.get_content_of_file( + self.workspace_slug, self.repo_slug, path, commit_id + ) except HTTPError as e: get_logger().debug(f"File {path} not found at commit id: {commit_id}") return file_content def get_files(self): - changes = self.bitbucket_client.get_pull_requests_changes(self.workspace_slug, self.repo_slug, self.pr_num) + changes = self.bitbucket_client.get_pull_requests_changes( + self.workspace_slug, self.repo_slug, self.pr_num + ) diffstat = [change["path"]['toString'] for change in changes] return diffstat - #gets the best common ancestor: https://git-scm.com/docs/git-merge-base + # gets the best common ancestor: https://git-scm.com/docs/git-merge-base @staticmethod - def get_best_common_ancestor(source_commits_list, destination_commits_list, guaranteed_common_ancestor) -> str: - destination_commit_hashes = {commit['id'] for commit in destination_commits_list} | {guaranteed_common_ancestor} + def 
get_best_common_ancestor( + source_commits_list, destination_commits_list, guaranteed_common_ancestor + ) -> str: + destination_commit_hashes = { + commit['id'] for commit in destination_commits_list + } | {guaranteed_common_ancestor} for commit in source_commits_list: for parent_commit in commit['parents']: @@ -177,37 +205,55 @@ class BitbucketServerProvider(GitProvider): head_sha = self.pr.fromRef['latestCommit'] # if Bitbucket api version is >= 8.16 then use the merge-base api for 2-way diff calculation - if self.bitbucket_api_version is not None and self.bitbucket_api_version >= parse_version("8.16"): + if ( + self.bitbucket_api_version is not None + and self.bitbucket_api_version >= parse_version("8.16") + ): try: base_sha = self.bitbucket_client.get(self._get_merge_base())['id'] except Exception as e: - get_logger().error(f"Failed to get the best common ancestor for PR: {self.pr_url}, \nerror: {e}") + get_logger().error( + f"Failed to get the best common ancestor for PR: {self.pr_url}, \nerror: {e}" + ) raise e else: - source_commits_list = list(self.bitbucket_client.get_pull_requests_commits( - self.workspace_slug, - self.repo_slug, - self.pr_num - )) + source_commits_list = list( + self.bitbucket_client.get_pull_requests_commits( + self.workspace_slug, self.repo_slug, self.pr_num + ) + ) # if Bitbucket api version is None or < 7.0 then do a simple diff with a guaranteed common ancestor base_sha = source_commits_list[-1]['parents'][0]['id'] # if Bitbucket api version is 7.0-8.15 then use 2-way diff functionality for the base_sha - if self.bitbucket_api_version is not None and self.bitbucket_api_version >= parse_version("7.0"): + if ( + self.bitbucket_api_version is not None + and self.bitbucket_api_version >= parse_version("7.0") + ): try: destination_commits = list( - self.bitbucket_client.get_commits(self.workspace_slug, self.repo_slug, base_sha, - self.pr.toRef['latestCommit'])) - base_sha = self.get_best_common_ancestor(source_commits_list, destination_commits, base_sha) + self.bitbucket_client.get_commits( + self.workspace_slug, + self.repo_slug, + base_sha, + self.pr.toRef['latestCommit'], + ) + ) + base_sha = self.get_best_common_ancestor( + source_commits_list, destination_commits, base_sha + ) except Exception as e: get_logger().error( - f"Failed to get the commit list for calculating best common ancestor for PR: {self.pr_url}, \nerror: {e}") + f"Failed to get the commit list for calculating best common ancestor for PR: {self.pr_url}, \nerror: {e}" + ) raise e diff_files = [] original_file_content_str = "" new_file_content_str = "" - changes = self.bitbucket_client.get_pull_requests_changes(self.workspace_slug, self.repo_slug, self.pr_num) + changes = self.bitbucket_client.get_pull_requests_changes( + self.workspace_slug, self.repo_slug, self.pr_num + ) for change in changes: file_path = change['path']['toString'] if not is_valid_file(file_path.split("/")[-1]): @@ -224,17 +270,26 @@ class BitbucketServerProvider(GitProvider): edit_type = EDIT_TYPE.DELETED new_file_content_str = "" original_file_content_str = self.get_file(file_path, base_sha) - original_file_content_str = decode_if_bytes(original_file_content_str) + original_file_content_str = decode_if_bytes( + original_file_content_str + ) case 'RENAME': edit_type = EDIT_TYPE.RENAMED case _: edit_type = EDIT_TYPE.MODIFIED original_file_content_str = self.get_file(file_path, base_sha) - original_file_content_str = decode_if_bytes(original_file_content_str) + original_file_content_str = decode_if_bytes( + 
original_file_content_str + ) new_file_content_str = self.get_file(file_path, head_sha) new_file_content_str = decode_if_bytes(new_file_content_str) - patch = load_large_diff(file_path, new_file_content_str, original_file_content_str, show_warning=False) + patch = load_large_diff( + file_path, + new_file_content_str, + original_file_content_str, + show_warning=False, + ) diff_files.append( FilePatchInfo( @@ -251,7 +306,9 @@ class BitbucketServerProvider(GitProvider): def publish_comment(self, pr_comment: str, is_temporary: bool = False): if not is_temporary: - self.bitbucket_client.add_pull_request_comment(self.workspace_slug, self.repo_slug, self.pr_num, pr_comment) + self.bitbucket_client.add_pull_request_comment( + self.workspace_slug, self.repo_slug, self.pr_num, pr_comment + ) def remove_initial_comment(self): try: @@ -264,25 +321,37 @@ class BitbucketServerProvider(GitProvider): pass # function to create_inline_comment - def create_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str, - absolute_position: int = None): - + def create_inline_comment( + self, + body: str, + relevant_file: str, + relevant_line_in_file: str, + absolute_position: int = None, + ): position, absolute_position = find_line_number_of_relevant_line_in_file( self.get_diff_files(), relevant_file.strip('`'), relevant_line_in_file, - absolute_position + absolute_position, ) if position == -1: if get_settings().config.verbosity_level >= 2: - get_logger().info(f"Could not find position for {relevant_file} {relevant_line_in_file}") + get_logger().info( + f"Could not find position for {relevant_file} {relevant_line_in_file}" + ) subject_type = "FILE" else: subject_type = "LINE" path = relevant_file.strip() - return dict(body=body, path=path, position=absolute_position) if subject_type == "LINE" else {} + return ( + dict(body=body, path=path, position=absolute_position) + if subject_type == "LINE" + else {} + ) - def publish_inline_comment(self, comment: str, from_line: int, file: str, original_suggestion=None): + def publish_inline_comment( + self, comment: str, from_line: int, file: str, original_suggestion=None + ): payload = { "text": comment, "severity": "NORMAL", @@ -291,17 +360,24 @@ class BitbucketServerProvider(GitProvider): "path": file, "lineType": "ADDED", "line": from_line, - "fileType": "TO" - } + "fileType": "TO", + }, } try: self.bitbucket_client.post(self._get_pr_comments_path(), data=payload) except Exception as e: - get_logger().error(f"Failed to publish inline comment to '{file}' at line {from_line}, error: {e}") + get_logger().error( + f"Failed to publish inline comment to '{file}' at line {from_line}, error: {e}" + ) raise e - def get_line_link(self, relevant_file: str, relevant_line_start: int, relevant_line_end: int = None) -> str: + def get_line_link( + self, + relevant_file: str, + relevant_line_start: int, + relevant_line_end: int = None, + ) -> str: if relevant_line_start == -1: link = f"{self.pr_url}/diff#{quote_plus(relevant_file)}" else: @@ -316,8 +392,9 @@ class BitbucketServerProvider(GitProvider): return "" diff_files = self.get_diff_files() - position, absolute_position = find_line_number_of_relevant_line_in_file \ - (diff_files, relevant_file, relevant_line_str) + position, absolute_position = find_line_number_of_relevant_line_in_file( + diff_files, relevant_file, relevant_line_str + ) if absolute_position != -1: if self.pr: @@ -325,29 +402,41 @@ class BitbucketServerProvider(GitProvider): return link else: if get_settings().config.verbosity_level >= 2: - 
get_logger().info(f"Failed adding line link to '{relevant_file}' since PR not set") + get_logger().info( + f"Failed adding line link to '{relevant_file}' since PR not set" + ) else: if get_settings().config.verbosity_level >= 2: - get_logger().info(f"Failed adding line link to '{relevant_file}' since position not found") + get_logger().info( + f"Failed adding line link to '{relevant_file}' since position not found" + ) if absolute_position != -1 and self.pr_url: link = f"{self.pr_url}/diff#{quote_plus(relevant_file)}?t={absolute_position}" return link except Exception as e: if get_settings().config.verbosity_level >= 2: - get_logger().info(f"Failed adding line link to '{relevant_file}', error: {e}") + get_logger().info( + f"Failed adding line link to '{relevant_file}', error: {e}" + ) return "" def publish_inline_comments(self, comments: list[dict]): for comment in comments: if 'position' in comment: - self.publish_inline_comment(comment['body'], comment['position'], comment['path']) - elif 'start_line' in comment: # multi-line comment + self.publish_inline_comment( + comment['body'], comment['position'], comment['path'] + ) + elif 'start_line' in comment: # multi-line comment # note that bitbucket does not seem to support range - only a comment on a single line - https://community.developer.atlassian.com/t/api-post-endpoint-for-inline-pull-request-comments/60452 - self.publish_inline_comment(comment['body'], comment['start_line'], comment['path']) - elif 'line' in comment: # single-line comment - self.publish_inline_comment(comment['body'], comment['line'], comment['path']) + self.publish_inline_comment( + comment['body'], comment['start_line'], comment['path'] + ) + elif 'line' in comment: # single-line comment + self.publish_inline_comment( + comment['body'], comment['line'], comment['path'] + ) else: get_logger().error(f"Could not publish inline comment: {comment}") @@ -377,7 +466,9 @@ class BitbucketServerProvider(GitProvider): "Bitbucket provider does not support issue comments yet" ) - def add_eyes_reaction(self, issue_comment_id: int, disable_eyes: bool = False) -> Optional[int]: + def add_eyes_reaction( + self, issue_comment_id: int, disable_eyes: bool = False + ) -> Optional[int]: return True def remove_reaction(self, issue_comment_id: int, reaction_id: int) -> bool: @@ -411,14 +502,20 @@ class BitbucketServerProvider(GitProvider): users_index = -1 if projects_index == -1 and users_index == -1: - raise ValueError(f"The provided URL '{pr_url}' does not appear to be a Bitbucket PR URL") + raise ValueError( + f"The provided URL '{pr_url}' does not appear to be a Bitbucket PR URL" + ) if projects_index != -1: path_parts = path_parts[projects_index:] else: path_parts = path_parts[users_index:] - if len(path_parts) < 6 or path_parts[2] != "repos" or path_parts[4] != "pull-requests": + if ( + len(path_parts) < 6 + or path_parts[2] != "repos" + or path_parts[4] != "pull-requests" + ): raise ValueError( f"The provided URL '{pr_url}' does not appear to be a Bitbucket PR URL" ) @@ -430,19 +527,24 @@ class BitbucketServerProvider(GitProvider): try: pr_number = int(path_parts[5]) except ValueError as e: - raise ValueError(f"Unable to convert PR number '{path_parts[5]}' to integer") from e + raise ValueError( + f"Unable to convert PR number '{path_parts[5]}' to integer" + ) from e return workspace_slug, repo_slug, pr_number def _get_repo(self): if self.repo is None: - self.repo = self.bitbucket_client.get_repo(self.workspace_slug, self.repo_slug) + self.repo = self.bitbucket_client.get_repo( + 
self.workspace_slug, self.repo_slug + ) return self.repo def _get_pr(self): try: - pr = self.bitbucket_client.get_pull_request(self.workspace_slug, self.repo_slug, - pull_request_id=self.pr_num) + pr = self.bitbucket_client.get_pull_request( + self.workspace_slug, self.repo_slug, pull_request_id=self.pr_num + ) return type('new_dict', (object,), pr) except Exception as e: get_logger().error(f"Failed to get pull request, error: {e}") @@ -460,10 +562,12 @@ class BitbucketServerProvider(GitProvider): "version": self.pr.version, "description": description, "title": pr_title, - "reviewers": self.pr.reviewers # needs to be sent otherwise gets wiped + "reviewers": self.pr.reviewers, # needs to be sent otherwise gets wiped } try: - self.bitbucket_client.update_pull_request(self.workspace_slug, self.repo_slug, str(self.pr_num), payload) + self.bitbucket_client.update_pull_request( + self.workspace_slug, self.repo_slug, str(self.pr_num), payload + ) except Exception as e: get_logger().error(f"Failed to update pull request, error: {e}") raise e diff --git a/apps/utils/pr_agent/git_providers/codecommit_client.py b/apps/utils/pr_agent/git_providers/codecommit_client.py index 5f18c90..e720590 100644 --- a/apps/utils/pr_agent/git_providers/codecommit_client.py +++ b/apps/utils/pr_agent/git_providers/codecommit_client.py @@ -31,7 +31,9 @@ class CodeCommitPullRequestResponse: self.targets = [] for target in json.get("pullRequestTargets", []): - self.targets.append(CodeCommitPullRequestResponse.CodeCommitPullRequestTarget(target)) + self.targets.append( + CodeCommitPullRequestResponse.CodeCommitPullRequestTarget(target) + ) class CodeCommitPullRequestTarget: """ @@ -65,7 +67,9 @@ class CodeCommitClient: except Exception as e: raise ValueError(f"Failed to connect to AWS CodeCommit: {e}") from e - def get_differences(self, repo_name: int, destination_commit: str, source_commit: str): + def get_differences( + self, repo_name: int, destination_commit: str, source_commit: str + ): """ Get the differences between two commits in CodeCommit. @@ -96,17 +100,25 @@ class CodeCommitClient: differences.extend(page.get("differences", [])) except botocore.exceptions.ClientError as e: if e.response["Error"]["Code"] == 'RepositoryDoesNotExistException': - raise ValueError(f"CodeCommit cannot retrieve differences: Repository does not exist: {repo_name}") from e - raise ValueError(f"CodeCommit cannot retrieve differences for {source_commit}..{destination_commit}") from e + raise ValueError( + f"CodeCommit cannot retrieve differences: Repository does not exist: {repo_name}" + ) from e + raise ValueError( + f"CodeCommit cannot retrieve differences for {source_commit}..{destination_commit}" + ) from e except Exception as e: - raise ValueError(f"CodeCommit cannot retrieve differences for {source_commit}..{destination_commit}") from e + raise ValueError( + f"CodeCommit cannot retrieve differences for {source_commit}..{destination_commit}" + ) from e output = [] for json in differences: output.append(CodeCommitDifferencesResponse(json)) return output - def get_file(self, repo_name: str, file_path: str, sha_hash: str, optional: bool = False): + def get_file( + self, repo_name: str, file_path: str, sha_hash: str, optional: bool = False + ): """ Retrieve a file from CodeCommit. 
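# --- Illustrative aside (not part of the patch) ---
# Sketch of the underlying boto3 call wrapped by CodeCommitClient.get_file() above;
# assumes AWS credentials are configured. fetch_file() is a hypothetical helper.
import boto3

def fetch_file(repo_name: str, file_path: str, commit: str) -> str:
    client = boto3.client("codecommit")
    response = client.get_file(
        repositoryName=repo_name, commitSpecifier=commit, filePath=file_path
    )
    return response["fileContent"].decode("utf-8")
# --- end aside ---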
@@ -129,16 +141,24 @@ class CodeCommitClient: self._connect_boto_client() try: - response = self.boto_client.get_file(repositoryName=repo_name, commitSpecifier=sha_hash, filePath=file_path) + response = self.boto_client.get_file( + repositoryName=repo_name, commitSpecifier=sha_hash, filePath=file_path + ) except botocore.exceptions.ClientError as e: if e.response["Error"]["Code"] == 'RepositoryDoesNotExistException': - raise ValueError(f"CodeCommit cannot retrieve PR: Repository does not exist: {repo_name}") from e + raise ValueError( + f"CodeCommit cannot retrieve PR: Repository does not exist: {repo_name}" + ) from e # if the file does not exist, but is flagged as optional, then return an empty string if optional and e.response["Error"]["Code"] == 'FileDoesNotExistException': return "" - raise ValueError(f"CodeCommit cannot retrieve file '{file_path}' from repository '{repo_name}'") from e + raise ValueError( + f"CodeCommit cannot retrieve file '{file_path}' from repository '{repo_name}'" + ) from e except Exception as e: - raise ValueError(f"CodeCommit cannot retrieve file '{file_path}' from repository '{repo_name}'") from e + raise ValueError( + f"CodeCommit cannot retrieve file '{file_path}' from repository '{repo_name}'" + ) from e if "fileContent" not in response: raise ValueError(f"File content is empty for file: {file_path}") @@ -166,10 +186,16 @@ class CodeCommitClient: response = self.boto_client.get_pull_request(pullRequestId=str(pr_number)) except botocore.exceptions.ClientError as e: if e.response["Error"]["Code"] == 'PullRequestDoesNotExistException': - raise ValueError(f"CodeCommit cannot retrieve PR: PR number does not exist: {pr_number}") from e + raise ValueError( + f"CodeCommit cannot retrieve PR: PR number does not exist: {pr_number}" + ) from e if e.response["Error"]["Code"] == 'RepositoryDoesNotExistException': - raise ValueError(f"CodeCommit cannot retrieve PR: Repository does not exist: {repo_name}") from e - raise ValueError(f"CodeCommit cannot retrieve PR: {pr_number}: boto client error") from e + raise ValueError( + f"CodeCommit cannot retrieve PR: Repository does not exist: {repo_name}" + ) from e + raise ValueError( + f"CodeCommit cannot retrieve PR: {pr_number}: boto client error" + ) from e except Exception as e: raise ValueError(f"CodeCommit cannot retrieve PR: {pr_number}") from e @@ -200,22 +226,37 @@ class CodeCommitClient: self._connect_boto_client() try: - self.boto_client.update_pull_request_title(pullRequestId=str(pr_number), title=pr_title) - self.boto_client.update_pull_request_description(pullRequestId=str(pr_number), description=pr_body) + self.boto_client.update_pull_request_title( + pullRequestId=str(pr_number), title=pr_title + ) + self.boto_client.update_pull_request_description( + pullRequestId=str(pr_number), description=pr_body + ) except botocore.exceptions.ClientError as e: if e.response["Error"]["Code"] == 'PullRequestDoesNotExistException': raise ValueError(f"PR number does not exist: {pr_number}") from e if e.response["Error"]["Code"] == 'InvalidTitleException': raise ValueError(f"Invalid title for PR number: {pr_number}") from e if e.response["Error"]["Code"] == 'InvalidDescriptionException': - raise ValueError(f"Invalid description for PR number: {pr_number}") from e + raise ValueError( + f"Invalid description for PR number: {pr_number}" + ) from e if e.response["Error"]["Code"] == 'PullRequestAlreadyClosedException': raise ValueError(f"PR is already closed: PR number: {pr_number}") from e raise ValueError(f"Boto3 client error 
calling publish_description") from e except Exception as e: raise ValueError(f"Error calling publish_description") from e - def publish_comment(self, repo_name: str, pr_number: int, destination_commit: str, source_commit: str, comment: str, annotation_file: str = None, annotation_line: int = None): + def publish_comment( + self, + repo_name: str, + pr_number: int, + destination_commit: str, + source_commit: str, + comment: str, + annotation_file: str = None, + annotation_line: int = None, + ): """ Publish a comment to a pull request @@ -272,6 +313,8 @@ class CodeCommitClient: raise ValueError(f"Repository does not exist: {repo_name}") from e if e.response["Error"]["Code"] == 'PullRequestDoesNotExistException': raise ValueError(f"PR number does not exist: {pr_number}") from e - raise ValueError(f"Boto3 client error calling post_comment_for_pull_request") from e + raise ValueError( + f"Boto3 client error calling post_comment_for_pull_request" + ) from e except Exception as e: raise ValueError(f"Error calling post_comment_for_pull_request") from e diff --git a/apps/utils/pr_agent/git_providers/codecommit_provider.py b/apps/utils/pr_agent/git_providers/codecommit_provider.py index 9e2669e..cbc663a 100644 --- a/apps/utils/pr_agent/git_providers/codecommit_provider.py +++ b/apps/utils/pr_agent/git_providers/codecommit_provider.py @@ -55,7 +55,9 @@ class CodeCommitProvider(GitProvider): This class implements the GitProvider interface for AWS CodeCommit repositories. """ - def __init__(self, pr_url: Optional[str] = None, incremental: Optional[bool] = False): + def __init__( + self, pr_url: Optional[str] = None, incremental: Optional[bool] = False + ): self.codecommit_client = CodeCommitClient() self.aws_client = None self.repo_name = None @@ -76,7 +78,7 @@ class CodeCommitProvider(GitProvider): "create_inline_comment", "publish_inline_comments", "get_labels", - "gfm_markdown" + "gfm_markdown", ]: return False return True @@ -91,13 +93,19 @@ class CodeCommitProvider(GitProvider): return self.git_files self.git_files = [] - differences = self.codecommit_client.get_differences(self.repo_name, self.pr.destination_commit, self.pr.source_commit) + differences = self.codecommit_client.get_differences( + self.repo_name, self.pr.destination_commit, self.pr.source_commit + ) for item in differences: - self.git_files.append(CodeCommitFile(item.before_blob_path, - item.before_blob_id, - item.after_blob_path, - item.after_blob_id, - CodeCommitProvider._get_edit_type(item.change_type))) + self.git_files.append( + CodeCommitFile( + item.before_blob_path, + item.before_blob_id, + item.after_blob_path, + item.after_blob_id, + CodeCommitProvider._get_edit_type(item.change_type), + ) + ) return self.git_files def get_diff_files(self) -> list[FilePatchInfo]: @@ -121,21 +129,28 @@ class CodeCommitProvider(GitProvider): if diff_item.a_blob_id is not None: patch_filename = diff_item.a_path original_file_content_str = self.codecommit_client.get_file( - self.repo_name, diff_item.a_path, self.pr.destination_commit) + self.repo_name, diff_item.a_path, self.pr.destination_commit + ) if isinstance(original_file_content_str, (bytes, bytearray)): - original_file_content_str = original_file_content_str.decode("utf-8") + original_file_content_str = original_file_content_str.decode( + "utf-8" + ) else: original_file_content_str = "" if diff_item.b_blob_id is not None: patch_filename = diff_item.b_path - new_file_content_str = self.codecommit_client.get_file(self.repo_name, diff_item.b_path, self.pr.source_commit) + 
new_file_content_str = self.codecommit_client.get_file( + self.repo_name, diff_item.b_path, self.pr.source_commit + ) if isinstance(new_file_content_str, (bytes, bytearray)): new_file_content_str = new_file_content_str.decode("utf-8") else: new_file_content_str = "" - patch = load_large_diff(patch_filename, new_file_content_str, original_file_content_str) + patch = load_large_diff( + patch_filename, new_file_content_str, original_file_content_str + ) # Store the diffs as a list of FilePatchInfo objects info = FilePatchInfo( @@ -164,7 +179,9 @@ class CodeCommitProvider(GitProvider): pr_body=CodeCommitProvider._add_additional_newlines(pr_body), ) except Exception as e: - raise ValueError(f"CodeCommit Cannot publish description for PR: {self.pr_num}") from e + raise ValueError( + f"CodeCommit Cannot publish description for PR: {self.pr_num}" + ) from e def publish_comment(self, pr_comment: str, is_temporary: bool = False): if is_temporary: @@ -183,19 +200,28 @@ class CodeCommitProvider(GitProvider): comment=pr_comment, ) except Exception as e: - raise ValueError(f"CodeCommit Cannot publish comment for PR: {self.pr_num}") from e + raise ValueError( + f"CodeCommit Cannot publish comment for PR: {self.pr_num}" + ) from e def publish_code_suggestions(self, code_suggestions: list) -> bool: counter = 1 for suggestion in code_suggestions: # Verify that each suggestion has the required keys - if not all(key in suggestion for key in ["body", "relevant_file", "relevant_lines_start"]): - get_logger().warning(f"Skipping code suggestion #{counter}: Each suggestion must have 'body', 'relevant_file', 'relevant_lines_start' keys") + if not all( + key in suggestion + for key in ["body", "relevant_file", "relevant_lines_start"] + ): + get_logger().warning( + f"Skipping code suggestion #{counter}: Each suggestion must have 'body', 'relevant_file', 'relevant_lines_start' keys" + ) continue # Publish the code suggestion to CodeCommit try: - get_logger().debug(f"Code Suggestion #{counter} in file: {suggestion['relevant_file']}: {suggestion['relevant_lines_start']}") + get_logger().debug( + f"Code Suggestion #{counter} in file: {suggestion['relevant_file']}: {suggestion['relevant_lines_start']}" + ) self.codecommit_client.publish_comment( repo_name=self.repo_name, pr_number=self.pr_num, @@ -206,7 +232,9 @@ class CodeCommitProvider(GitProvider): annotation_line=suggestion["relevant_lines_start"], ) except Exception as e: - raise ValueError(f"CodeCommit Cannot publish code suggestions for PR: {self.pr_num}") from e + raise ValueError( + f"CodeCommit Cannot publish code suggestions for PR: {self.pr_num}" + ) from e counter += 1 @@ -227,12 +255,22 @@ class CodeCommitProvider(GitProvider): def remove_comment(self, comment): return "" # not implemented yet - def publish_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str, original_suggestion=None): + def publish_inline_comment( + self, + body: str, + relevant_file: str, + relevant_line_in_file: str, + original_suggestion=None, + ): # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/codecommit/client/post_comment_for_compared_commit.html - raise NotImplementedError("CodeCommit provider does not support publishing inline comments yet") + raise NotImplementedError( + "CodeCommit provider does not support publishing inline comments yet" + ) def publish_inline_comments(self, comments: list[dict]): - raise NotImplementedError("CodeCommit provider does not support publishing inline comments yet") + raise NotImplementedError( 
+ "CodeCommit provider does not support publishing inline comments yet" + ) def get_title(self): return self.pr.title @@ -257,7 +295,7 @@ class CodeCommitProvider(GitProvider): - dict: A dictionary where each key is a language name and the corresponding value is the percentage of that language in the PR. """ commit_files = self.get_files() - filenames = [ item.filename for item in commit_files ] + filenames = [item.filename for item in commit_files] extensions = CodeCommitProvider._get_file_extensions(filenames) # Calculate the percentage of each file extension in the PR @@ -270,7 +308,9 @@ class CodeCommitProvider(GitProvider): # We build that language->extension dictionary here in main_extensions_flat. main_extensions_flat = {} language_extension_map_org = get_settings().language_extension_map_org - language_extension_map = {k.lower(): v for k, v in language_extension_map_org.items()} + language_extension_map = { + k.lower(): v for k, v in language_extension_map_org.items() + } for language, extensions in language_extension_map.items(): for ext in extensions: main_extensions_flat[ext] = language @@ -292,14 +332,20 @@ class CodeCommitProvider(GitProvider): return -1 # not implemented yet def get_issue_comments(self): - raise NotImplementedError("CodeCommit provider does not support issue comments yet") + raise NotImplementedError( + "CodeCommit provider does not support issue comments yet" + ) def get_repo_settings(self): # a local ".pr_agent.toml" settings file is optional settings_filename = ".pr_agent.toml" - return self.codecommit_client.get_file(self.repo_name, settings_filename, self.pr.source_commit, optional=True) + return self.codecommit_client.get_file( + self.repo_name, settings_filename, self.pr.source_commit, optional=True + ) - def add_eyes_reaction(self, issue_comment_id: int, disable_eyes: bool = False) -> Optional[int]: + def add_eyes_reaction( + self, issue_comment_id: int, disable_eyes: bool = False + ) -> Optional[int]: get_logger().info("CodeCommit provider does not support eyes reaction yet") return True @@ -323,7 +369,9 @@ class CodeCommitProvider(GitProvider): parsed_url = urlparse(pr_url) if not CodeCommitProvider._is_valid_codecommit_hostname(parsed_url.netloc): - raise ValueError(f"The provided URL is not a valid CodeCommit URL: {pr_url}") + raise ValueError( + f"The provided URL is not a valid CodeCommit URL: {pr_url}" + ) path_parts = parsed_url.path.strip("/").split("/") @@ -334,14 +382,18 @@ class CodeCommitProvider(GitProvider): or path_parts[2] != "repositories" or path_parts[4] != "pull-requests" ): - raise ValueError(f"The provided URL does not appear to be a CodeCommit PR URL: {pr_url}") + raise ValueError( + f"The provided URL does not appear to be a CodeCommit PR URL: {pr_url}" + ) repo_name = path_parts[3] try: pr_number = int(path_parts[5]) except ValueError as e: - raise ValueError(f"Unable to convert PR number to integer: '{path_parts[5]}'") from e + raise ValueError( + f"Unable to convert PR number to integer: '{path_parts[5]}'" + ) from e return repo_name, pr_number @@ -359,7 +411,12 @@ class CodeCommitProvider(GitProvider): Returns: - bool: True if the hostname is valid, False otherwise. 
""" - return re.match(r"^[a-z]{2}-(gov-)?[a-z]+-\d\.console\.aws\.amazon\.com$", hostname) is not None + return ( + re.match( + r"^[a-z]{2}-(gov-)?[a-z]+-\d\.console\.aws\.amazon\.com$", hostname + ) + is not None + ) def _get_pr(self): response = self.codecommit_client.get_pr(self.repo_name, self.pr_num) diff --git a/apps/utils/pr_agent/git_providers/gerrit_provider.py b/apps/utils/pr_agent/git_providers/gerrit_provider.py index 7ab4688..09c511e 100644 --- a/apps/utils/pr_agent/git_providers/gerrit_provider.py +++ b/apps/utils/pr_agent/git_providers/gerrit_provider.py @@ -38,10 +38,7 @@ def clone(url, directory): def fetch(url, refspec, cwd): get_logger().info("Fetching %s %s", url, refspec) - stdout = _call( - 'git', 'fetch', '--depth', '2', url, refspec, - cwd=cwd - ) + stdout = _call('git', 'fetch', '--depth', '2', url, refspec, cwd=cwd) get_logger().info(stdout) @@ -75,10 +72,13 @@ def add_comment(url: urllib3.util.Url, refspec, message): message = "'" + message.replace("'", "'\"'\"'") + "'" return _call( "ssh", - "-p", str(url.port), + "-p", + str(url.port), f"{url.auth}@{url.host}", - "gerrit", "review", - "--message", message, + "gerrit", + "review", + "--message", + message, # "--code-review", score, f"{patchset},{changenum}", ) @@ -88,19 +88,23 @@ def list_comments(url: urllib3.util.Url, refspec): *_, patchset, _ = refspec.rsplit("/") stdout = _call( "ssh", - "-p", str(url.port), + "-p", + str(url.port), f"{url.auth}@{url.host}", - "gerrit", "query", + "gerrit", + "query", "--comments", - "--current-patch-set", patchset, - "--format", "JSON", + "--current-patch-set", + patchset, + "--format", + "JSON", ) change_set, *_ = stdout.splitlines() return json.loads(change_set)["currentPatchSet"]["comments"] def prepare_repo(url: urllib3.util.Url, project, refspec): - repo_url = (f"{url.scheme}://{url.auth}@{url.host}:{url.port}/{project}") + repo_url = f"{url.scheme}://{url.auth}@{url.host}:{url.port}/{project}" directory = pathlib.Path(mkdtemp()) clone(repo_url, directory), @@ -114,18 +118,18 @@ def adopt_to_gerrit_message(message): buf = [] for line in lines: # remove markdown formatting - line = (line.replace("*", "") - .replace("``", "`") - .replace("
", "") - .replace("
", "") - .replace("", "") - .replace("", "")) + line = ( + line.replace("*", "") + .replace("``", "`") + .replace("
", "") + .replace("
", "") + .replace("", "") + .replace("", "") + ) line = line.strip() if line.startswith('#'): - buf.append("\n" + - line.replace('#', '').removesuffix(":").strip() + - ":") + buf.append("\n" + line.replace('#', '').removesuffix(":").strip() + ":") continue elif line.startswith('-'): buf.append(line.removeprefix('-').strip()) @@ -136,12 +140,9 @@ def adopt_to_gerrit_message(message): def add_suggestion(src_filename, context: str, start, end: int): - with ( - NamedTemporaryFile("w", delete=False) as tmp, - open(src_filename, "r") as src - ): + with NamedTemporaryFile("w", delete=False) as tmp, open(src_filename, "r") as src: lines = src.readlines() - tmp.writelines(lines[:start - 1]) + tmp.writelines(lines[: start - 1]) if context: tmp.write(context) tmp.writelines(lines[end:]) @@ -151,10 +152,8 @@ def add_suggestion(src_filename, context: str, start, end: int): def upload_patch(patch, path): - patch_server_endpoint = get_settings().get( - 'gerrit.patch_server_endpoint') - patch_server_token = get_settings().get( - 'gerrit.patch_server_token') + patch_server_endpoint = get_settings().get('gerrit.patch_server_endpoint') + patch_server_token = get_settings().get('gerrit.patch_server_token') response = requests.post( patch_server_endpoint, @@ -165,7 +164,7 @@ def upload_patch(patch, path): headers={ "Content-Type": "application/json", "Authorization": f"Bearer {patch_server_token}", - } + }, ) response.raise_for_status() patch_server_endpoint = patch_server_endpoint.rstrip("/") @@ -173,7 +172,6 @@ def upload_patch(patch, path): class GerritProvider(GitProvider): - def __init__(self, key: str, incremental=False): self.project, self.refspec = key.split(':') assert self.project, "Project name is required" @@ -188,9 +186,7 @@ class GerritProvider(GitProvider): f"{parsed.scheme}://{user}@{parsed.host}:{parsed.port}" ) - self.repo_path = prepare_repo( - self.parsed_url, self.project, self.refspec - ) + self.repo_path = prepare_repo(self.parsed_url, self.project, self.refspec) self.repo = Repo(self.repo_path) assert self.repo self.pr_url = base_url @@ -210,15 +206,18 @@ class GerritProvider(GitProvider): def get_pr_labels(self, update=False): raise NotImplementedError( - 'Getting labels is not implemented for the gerrit provider') + 'Getting labels is not implemented for the gerrit provider' + ) def add_eyes_reaction(self, issue_comment_id: int, disable_eyes: bool = False): raise NotImplementedError( - 'Adding reactions is not implemented for the gerrit provider') + 'Adding reactions is not implemented for the gerrit provider' + ) def remove_reaction(self, issue_comment_id: int, reaction_id: int): raise NotImplementedError( - 'Removing reactions is not implemented for the gerrit provider') + 'Removing reactions is not implemented for the gerrit provider' + ) def get_commit_messages(self): return [self.repo.head.commit.message] @@ -235,20 +234,21 @@ class GerritProvider(GitProvider): diffs = self.repo.head.commit.diff( self.repo.head.commit.parents[0], # previous commit create_patch=True, - R=True + R=True, ) diff_files = [] for diff_item in diffs: if diff_item.a_blob is not None: - original_file_content_str = ( - diff_item.a_blob.data_stream.read().decode('utf-8') + original_file_content_str = diff_item.a_blob.data_stream.read().decode( + 'utf-8' ) else: original_file_content_str = "" # empty file if diff_item.b_blob is not None: - new_file_content_str = diff_item.b_blob.data_stream.read(). 
\ - decode('utf-8') + new_file_content_str = diff_item.b_blob.data_stream.read().decode( + 'utf-8' + ) else: new_file_content_str = "" # empty file edit_type = EDIT_TYPE.MODIFIED @@ -267,7 +267,7 @@ class GerritProvider(GitProvider): edit_type=edit_type, old_filename=None if diff_item.a_path == diff_item.b_path - else diff_item.a_path + else diff_item.a_path, ) ) self.diff_files = diff_files @@ -275,8 +275,7 @@ class GerritProvider(GitProvider): def get_files(self): diff_index = self.repo.head.commit.diff( - self.repo.head.commit.parents[0], # previous commit - R=True + self.repo.head.commit.parents[0], R=True # previous commit ) # Get the list of changed files diff_files = [item.a_path for item in diff_index] @@ -288,16 +287,22 @@ class GerritProvider(GitProvider): prioritisation. """ # Get all files in repository - filepaths = [Path(item.path) for item in - self.repo.tree().traverse() if item.type == 'blob'] + filepaths = [ + Path(item.path) + for item in self.repo.tree().traverse() + if item.type == 'blob' + ] # Identify language by file extension and count lang_count = Counter( - ext.lstrip('.') for filepath in filepaths for ext in - [filepath.suffix.lower()]) + ext.lstrip('.') + for filepath in filepaths + for ext in [filepath.suffix.lower()] + ) # Convert counts to percentages total_files = len(filepaths) - lang_percentage = {lang: count / total_files * 100 for lang, count - in lang_count.items()} + lang_percentage = { + lang: count / total_files * 100 for lang, count in lang_count.items() + } return lang_percentage def get_pr_description_full(self): @@ -312,7 +317,7 @@ class GerritProvider(GitProvider): 'create_inline_comment', 'publish_inline_comments', 'get_labels', - 'gfm_markdown' + 'gfm_markdown', ]: return False return True @@ -331,14 +336,9 @@ class GerritProvider(GitProvider): if is_code_context: context.append(line) else: - description.append( - line.replace('*', '') - ) + description.append(line.replace('*', '')) - return ( - '\n'.join(description), - '\n'.join(context) + '\n' if context else '' - ) + return ('\n'.join(description), '\n'.join(context) + '\n' if context else '') def publish_code_suggestions(self, code_suggestions: list): msg = [] @@ -372,15 +372,19 @@ class GerritProvider(GitProvider): def publish_inline_comments(self, comments: list[dict]): raise NotImplementedError( - 'Publishing inline comments is not implemented for the gerrit ' - 'provider') + 'Publishing inline comments is not implemented for the gerrit ' 'provider' + ) - def publish_inline_comment(self, body: str, relevant_file: str, - relevant_line_in_file: str, original_suggestion=None): + def publish_inline_comment( + self, + body: str, + relevant_file: str, + relevant_line_in_file: str, + original_suggestion=None, + ): raise NotImplementedError( - 'Publishing inline comments is not implemented for the gerrit ' - 'provider') - + 'Publishing inline comments is not implemented for the gerrit ' 'provider' + ) def publish_labels(self, labels): # Not applicable to the local git provider, diff --git a/apps/utils/pr_agent/git_providers/git_provider.py b/apps/utils/pr_agent/git_providers/git_provider.py index c8331ec..e288f22 100644 --- a/apps/utils/pr_agent/git_providers/git_provider.py +++ b/apps/utils/pr_agent/git_providers/git_provider.py @@ -1,4 +1,5 @@ from abc import ABC, abstractmethod + # enum EDIT_TYPE (ADDED, DELETED, MODIFIED, RENAMED) from typing import Optional @@ -9,6 +10,7 @@ from utils.pr_agent.log import get_logger MAX_FILES_ALLOWED_FULL = 50 + class GitProvider(ABC): @abstractmethod def 
is_supported(self, capability: str) -> bool: @@ -61,11 +63,18 @@ class GitProvider(ABC): def reply_to_comment_from_comment_id(self, comment_id: int, body: str): pass - def get_pr_description(self, full: bool = True, split_changes_walkthrough=False) -> str or tuple: + def get_pr_description( + self, full: bool = True, split_changes_walkthrough=False + ) -> str or tuple: from utils.pr_agent.algo.utils import clip_tokens from utils.pr_agent.config_loader import get_settings - max_tokens_description = get_settings().get("CONFIG.MAX_DESCRIPTION_TOKENS", None) - description = self.get_pr_description_full() if full else self.get_user_description() + + max_tokens_description = get_settings().get( + "CONFIG.MAX_DESCRIPTION_TOKENS", None + ) + description = ( + self.get_pr_description_full() if full else self.get_user_description() + ) if split_changes_walkthrough: description, files = process_description(description) if max_tokens_description: @@ -94,7 +103,9 @@ class GitProvider(ABC): # return nothing (empty string) because it means there is no user description user_description_header = "### **user description**" if user_description_header not in description_lowercase: - get_logger().info(f"Existing description was generated by the pr-agent, but it doesn't contain a user description") + get_logger().info( + f"Existing description was generated by the pr-agent, but it doesn't contain a user description" + ) return "" # otherwise, extract the original user description from the existing pr-agent description and return it @@ -103,9 +114,11 @@ class GitProvider(ABC): # the 'user description' is in the beginning. extract and return it possible_headers = self._possible_headers() - start_position = description_lowercase.find(user_description_header) + len(user_description_header) + start_position = description_lowercase.find(user_description_header) + len( + user_description_header + ) end_position = len(description) - for header in possible_headers: # try to clip at the next header + for header in possible_headers: # try to clip at the next header if header != user_description_header and header in description_lowercase: end_position = min(end_position, description_lowercase.find(header)) if end_position != len(description) and end_position > start_position: @@ -115,20 +128,34 @@ class GitProvider(ABC): else: original_user_description = description.split("___")[0].strip() if original_user_description.lower().startswith(user_description_header): - original_user_description = original_user_description[len(user_description_header):].strip() + original_user_description = original_user_description[ + len(user_description_header) : + ].strip() - get_logger().info(f"Extracted user description from existing description", - description=original_user_description) + get_logger().info( + f"Extracted user description from existing description", + description=original_user_description, + ) self.user_description = original_user_description return original_user_description def _possible_headers(self): - return ("### **user description**", "### **pr type**", "### **pr description**", "### **pr labels**", "### **type**", "### **description**", - "### **labels**", "### 🤖 generated by pr agent") + return ( + "### **user description**", + "### **pr type**", + "### **pr description**", + "### **pr labels**", + "### **type**", + "### **description**", + "### **labels**", + "### 🤖 generated by pr agent", + ) def _is_generated_by_pr_agent(self, description_lowercase: str) -> bool: possible_headers = self._possible_headers() - 
return any(description_lowercase.startswith(header) for header in possible_headers) + return any( + description_lowercase.startswith(header) for header in possible_headers + ) @abstractmethod def get_repo_settings(self): @@ -140,10 +167,17 @@ class GitProvider(ABC): def get_pr_id(self): return "" - def get_line_link(self, relevant_file: str, relevant_line_start: int, relevant_line_end: int = None) -> str: + def get_line_link( + self, + relevant_file: str, + relevant_line_start: int, + relevant_line_end: int = None, + ) -> str: return "" - def get_lines_link_original_file(self, filepath:str, component_range: Range) -> str: + def get_lines_link_original_file( + self, filepath: str, component_range: Range + ) -> str: return "" #### comments operations #### @@ -151,18 +185,24 @@ class GitProvider(ABC): def publish_comment(self, pr_comment: str, is_temporary: bool = False): pass - def publish_persistent_comment(self, pr_comment: str, - initial_header: str, - update_header: bool = True, - name='review', - final_update_message=True): + def publish_persistent_comment( + self, + pr_comment: str, + initial_header: str, + update_header: bool = True, + name='review', + final_update_message=True, + ): self.publish_comment(pr_comment) - def publish_persistent_comment_full(self, pr_comment: str, - initial_header: str, - update_header: bool = True, - name='review', - final_update_message=True): + def publish_persistent_comment_full( + self, + pr_comment: str, + initial_header: str, + update_header: bool = True, + name='review', + final_update_message=True, + ): try: prev_comments = list(self.get_issue_comments()) for comment in prev_comments: @@ -171,29 +211,46 @@ class GitProvider(ABC): comment_url = self.get_comment_url(comment) if update_header: updated_header = f"{initial_header}\n\n#### ({name.capitalize()} updated until commit {latest_commit_url})\n" - pr_comment_updated = pr_comment.replace(initial_header, updated_header) + pr_comment_updated = pr_comment.replace( + initial_header, updated_header + ) else: pr_comment_updated = pr_comment - get_logger().info(f"Persistent mode - updating comment {comment_url} to latest {name} message") + get_logger().info( + f"Persistent mode - updating comment {comment_url} to latest {name} message" + ) # response = self.mr.notes.update(comment.id, {'body': pr_comment_updated}) self.edit_comment(comment, pr_comment_updated) if final_update_message: self.publish_comment( - f"**[Persistent {name}]({comment_url})** updated to latest commit {latest_commit_url}") + f"**[Persistent {name}]({comment_url})** updated to latest commit {latest_commit_url}" + ) return except Exception as e: get_logger().exception(f"Failed to update persistent review, error: {e}") pass self.publish_comment(pr_comment) - @abstractmethod - def publish_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str, original_suggestion=None): + def publish_inline_comment( + self, + body: str, + relevant_file: str, + relevant_line_in_file: str, + original_suggestion=None, + ): pass - def create_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str, - absolute_position: int = None): - raise NotImplementedError("This git provider does not support creating inline comments yet") + def create_inline_comment( + self, + body: str, + relevant_file: str, + relevant_line_in_file: str, + absolute_position: int = None, + ): + raise NotImplementedError( + "This git provider does not support creating inline comments yet" + ) @abstractmethod def publish_inline_comments(self, 
comments: list[dict]): @@ -227,7 +284,9 @@ class GitProvider(ABC): pass @abstractmethod - def add_eyes_reaction(self, issue_comment_id: int, disable_eyes: bool = False) -> Optional[int]: + def add_eyes_reaction( + self, issue_comment_id: int, disable_eyes: bool = False + ) -> Optional[int]: pass @abstractmethod @@ -284,16 +343,23 @@ def get_main_pr_language(languages, files) -> str: if not file: continue if isinstance(file, str): - file = FilePatchInfo(base_file=None, head_file=None, patch=None, filename=file) + file = FilePatchInfo( + base_file=None, head_file=None, patch=None, filename=file + ) extension_list.append(file.filename.rsplit('.')[-1]) # get the most common extension most_common_extension = '.' + max(set(extension_list), key=extension_list.count) try: language_extension_map_org = get_settings().language_extension_map_org - language_extension_map = {k.lower(): v for k, v in language_extension_map_org.items()} + language_extension_map = { + k.lower(): v for k, v in language_extension_map_org.items() + } - if top_language in language_extension_map and most_common_extension in language_extension_map[top_language]: + if ( + top_language in language_extension_map + and most_common_extension in language_extension_map[top_language] + ): main_language_str = top_language else: for language, extensions in language_extension_map.items(): @@ -332,8 +398,6 @@ def get_main_pr_language(languages, files) -> str: return main_language_str - - class IncrementalPR: def __init__(self, is_incremental: bool = False): self.is_incremental = is_incremental diff --git a/apps/utils/pr_agent/git_providers/github_provider.py b/apps/utils/pr_agent/git_providers/github_provider.py index 6a24965..dbd56b4 100644 --- a/apps/utils/pr_agent/git_providers/github_provider.py +++ b/apps/utils/pr_agent/git_providers/github_provider.py @@ -18,14 +18,23 @@ from ..algo.file_filter import filter_ignored from ..algo.git_patch_processing import extract_hunk_headers from ..algo.language_handler import is_valid_file from ..algo.types import EDIT_TYPE -from ..algo.utils import (PRReviewHeader, Range, clip_tokens, - find_line_number_of_relevant_line_in_file, - load_large_diff, set_file_languages) +from ..algo.utils import ( + PRReviewHeader, + Range, + clip_tokens, + find_line_number_of_relevant_line_in_file, + load_large_diff, + set_file_languages, +) from ..config_loader import get_settings from ..log import get_logger from ..servers.utils import RateLimitExceeded -from .git_provider import (MAX_FILES_ALLOWED_FULL, FilePatchInfo, GitProvider, - IncrementalPR) +from .git_provider import ( + MAX_FILES_ALLOWED_FULL, + FilePatchInfo, + GitProvider, + IncrementalPR, +) class GithubProvider(GitProvider): @@ -36,8 +45,14 @@ class GithubProvider(GitProvider): except Exception: self.installation_id = None self.max_comment_chars = 65000 - self.base_url = get_settings().get("GITHUB.BASE_URL", "https://api.github.com").rstrip("/") # "https://api.github.com" - self.base_url_html = self.base_url.split("api/")[0].rstrip("/") if "api/" in self.base_url else "https://github.com" + self.base_url = ( + get_settings().get("GITHUB.BASE_URL", "https://api.github.com").rstrip("/") + ) # "https://api.github.com" + self.base_url_html = ( + self.base_url.split("api/")[0].rstrip("/") + if "api/" in self.base_url + else "https://github.com" + ) self.github_client = self._get_github_client() self.repo = None self.pr_num = None @@ -50,7 +65,9 @@ class GithubProvider(GitProvider): self.set_pr(pr_url) self.pr_commits = list(self.pr.get_commits()) 
self.last_commit_id = self.pr_commits[-1] - self.pr_url = self.get_pr_url() # pr_url for github actions can be as api.github.com, so we need to get the url from the pr object + self.pr_url = ( + self.get_pr_url() + ) # pr_url for github actions can be as api.github.com, so we need to get the url from the pr object else: self.pr_commits = None @@ -80,10 +97,14 @@ class GithubProvider(GitProvider): # Get all files changed during the commit range for commit in self.incremental.commits_range: - if commit.commit.message.startswith(f"Merge branch '{self._get_repo().default_branch}'"): + if commit.commit.message.startswith( + f"Merge branch '{self._get_repo().default_branch}'" + ): get_logger().info(f"Skipping merge commit {commit.commit.message}") continue - self.unreviewed_files_set.update({file.filename: file for file in commit.files}) + self.unreviewed_files_set.update( + {file.filename: file for file in commit.files} + ) else: get_logger().info("No previous review found, will review the entire PR") self.incremental.is_incremental = False @@ -98,7 +119,11 @@ class GithubProvider(GitProvider): else: self.incremental.last_seen_commit = self.pr_commits[index] break - return self.pr_commits[first_new_commit_index:] if first_new_commit_index is not None else [] + return ( + self.pr_commits[first_new_commit_index:] + if first_new_commit_index is not None + else [] + ) def get_previous_review(self, *, full: bool, incremental: bool): if not (full or incremental): @@ -121,7 +146,7 @@ class GithubProvider(GitProvider): git_files = context.get("git_files", None) if git_files: return git_files - self.git_files = list(self.pr.get_files()) # 'list' to handle pagination + self.git_files = list(self.pr.get_files()) # 'list' to handle pagination context["git_files"] = self.git_files return self.git_files except Exception: @@ -138,8 +163,13 @@ class GithubProvider(GitProvider): except Exception as e: return -1 - @retry(exceptions=RateLimitExceeded, - tries=get_settings().github.ratelimit_retries, delay=2, backoff=2, jitter=(1, 3)) + @retry( + exceptions=RateLimitExceeded, + tries=get_settings().github.ratelimit_retries, + delay=2, + backoff=2, + jitter=(1, 3), + ) def get_diff_files(self) -> list[FilePatchInfo]: """ Retrieves the list of files that have been modified, added, deleted, or renamed in a pull request in GitHub, @@ -167,9 +197,10 @@ class GithubProvider(GitProvider): try: names_original = [file.filename for file in files_original] names_new = [file.filename for file in files] - get_logger().info(f"Filtered out [ignore] files for pull request:", extra= - {"files": names_original, - "filtered_files": names_new}) + get_logger().info( + f"Filtered out [ignore] files for pull request:", + extra={"files": names_original, "filtered_files": names_new}, + ) except Exception: pass @@ -184,14 +215,17 @@ class GithubProvider(GitProvider): repo = self.repo_obj pr = self.pr try: - compare = repo.compare(pr.base.sha, pr.head.sha) # communication with GitHub + compare = repo.compare( + pr.base.sha, pr.head.sha + ) # communication with GitHub merge_base_commit = compare.merge_base_commit except Exception as e: get_logger().error(f"Failed to get merge base commit: {e}") merge_base_commit = pr.base if merge_base_commit.sha != pr.base.sha: get_logger().info( - f"Using merge base commit {merge_base_commit.sha} instead of base commit ") + f"Using merge base commit {merge_base_commit.sha} instead of base commit " + ) counter_valid = 0 for file in files: @@ -207,29 +241,48 @@ class GithubProvider(GitProvider): # allow only 
a limited number of files to be fully loaded. We can manage the rest with diffs only counter_valid += 1 avoid_load = False - if counter_valid >= MAX_FILES_ALLOWED_FULL and patch and not self.incremental.is_incremental: + if ( + counter_valid >= MAX_FILES_ALLOWED_FULL + and patch + and not self.incremental.is_incremental + ): avoid_load = True if counter_valid == MAX_FILES_ALLOWED_FULL: - get_logger().info(f"Too many files in PR, will avoid loading full content for rest of files") + get_logger().info( + f"Too many files in PR, will avoid loading full content for rest of files" + ) if avoid_load: new_file_content_str = "" else: - new_file_content_str = self._get_pr_file_content(file, self.pr.head.sha) # communication with GitHub + new_file_content_str = self._get_pr_file_content( + file, self.pr.head.sha + ) # communication with GitHub if self.incremental.is_incremental and self.unreviewed_files_set: - original_file_content_str = self._get_pr_file_content(file, self.incremental.last_seen_commit_sha) - patch = load_large_diff(file.filename, new_file_content_str, original_file_content_str) + original_file_content_str = self._get_pr_file_content( + file, self.incremental.last_seen_commit_sha + ) + patch = load_large_diff( + file.filename, + new_file_content_str, + original_file_content_str, + ) self.unreviewed_files_set[file.filename] = patch else: if avoid_load: original_file_content_str = "" else: - original_file_content_str = self._get_pr_file_content(file, merge_base_commit.sha) + original_file_content_str = self._get_pr_file_content( + file, merge_base_commit.sha + ) # original_file_content_str = self._get_pr_file_content(file, self.pr.base.sha) if not patch: - patch = load_large_diff(file.filename, new_file_content_str, original_file_content_str) - + patch = load_large_diff( + file.filename, + new_file_content_str, + original_file_content_str, + ) if file.status == 'added': edit_type = EDIT_TYPE.ADDED @@ -249,16 +302,27 @@ class GithubProvider(GitProvider): num_minus_lines = file.deletions else: patch_lines = patch.splitlines(keepends=True) - num_plus_lines = len([line for line in patch_lines if line.startswith('+')]) - num_minus_lines = len([line for line in patch_lines if line.startswith('-')]) + num_plus_lines = len( + [line for line in patch_lines if line.startswith('+')] + ) + num_minus_lines = len( + [line for line in patch_lines if line.startswith('-')] + ) - file_patch_canonical_structure = FilePatchInfo(original_file_content_str, new_file_content_str, patch, - file.filename, edit_type=edit_type, - num_plus_lines=num_plus_lines, - num_minus_lines=num_minus_lines,) + file_patch_canonical_structure = FilePatchInfo( + original_file_content_str, + new_file_content_str, + patch, + file.filename, + edit_type=edit_type, + num_plus_lines=num_plus_lines, + num_minus_lines=num_minus_lines, + ) diff_files.append(file_patch_canonical_structure) if invalid_files_names: - get_logger().info(f"Filtered out files with invalid extensions: {invalid_files_names}") + get_logger().info( + f"Filtered out files with invalid extensions: {invalid_files_names}" + ) self.diff_files = diff_files try: @@ -269,8 +333,10 @@ class GithubProvider(GitProvider): return diff_files except Exception as e: - get_logger().error(f"Failing to get diff files: {e}", - artifact={"traceback": traceback.format_exc()}) + get_logger().error( + f"Failing to get diff files: {e}", + artifact={"traceback": traceback.format_exc()}, + ) raise RateLimitExceeded("Rate limit exceeded for GitHub API.") from e def publish_description(self, 
pr_title: str, pr_body: str): @@ -282,16 +348,23 @@ class GithubProvider(GitProvider): def get_comment_url(self, comment) -> str: return comment.html_url - def publish_persistent_comment(self, pr_comment: str, - initial_header: str, - update_header: bool = True, - name='review', - final_update_message=True): - self.publish_persistent_comment_full(pr_comment, initial_header, update_header, name, final_update_message) + def publish_persistent_comment( + self, + pr_comment: str, + initial_header: str, + update_header: bool = True, + name='review', + final_update_message=True, + ): + self.publish_persistent_comment_full( + pr_comment, initial_header, update_header, name, final_update_message + ) def publish_comment(self, pr_comment: str, is_temporary: bool = False): if is_temporary and not get_settings().config.publish_output_progress: - get_logger().debug(f"Skipping publish_comment for temporary comment: {pr_comment}") + get_logger().debug( + f"Skipping publish_comment for temporary comment: {pr_comment}" + ) return None pr_comment = self.limit_output_characters(pr_comment, self.max_comment_chars) response = self.pr.create_issue_comment(pr_comment) @@ -303,42 +376,68 @@ class GithubProvider(GitProvider): self.pr.comments_list.append(response) return response - def publish_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str, original_suggestion=None): + def publish_inline_comment( + self, + body: str, + relevant_file: str, + relevant_line_in_file: str, + original_suggestion=None, + ): body = self.limit_output_characters(body, self.max_comment_chars) - self.publish_inline_comments([self.create_inline_comment(body, relevant_file, relevant_line_in_file)]) + self.publish_inline_comments( + [self.create_inline_comment(body, relevant_file, relevant_line_in_file)] + ) - - def create_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str, - absolute_position: int = None): + def create_inline_comment( + self, + body: str, + relevant_file: str, + relevant_line_in_file: str, + absolute_position: int = None, + ): body = self.limit_output_characters(body, self.max_comment_chars) - position, absolute_position = find_line_number_of_relevant_line_in_file(self.diff_files, - relevant_file.strip('`'), - relevant_line_in_file, - absolute_position) + position, absolute_position = find_line_number_of_relevant_line_in_file( + self.diff_files, + relevant_file.strip('`'), + relevant_line_in_file, + absolute_position, + ) if position == -1: - get_logger().info(f"Could not find position for {relevant_file} {relevant_line_in_file}") + get_logger().info( + f"Could not find position for {relevant_file} {relevant_line_in_file}" + ) subject_type = "FILE" else: subject_type = "LINE" path = relevant_file.strip() - return dict(body=body, path=path, position=position) if subject_type == "LINE" else {} + return ( + dict(body=body, path=path, position=position) + if subject_type == "LINE" + else {} + ) - def publish_inline_comments(self, comments: list[dict], disable_fallback: bool = False): + def publish_inline_comments( + self, comments: list[dict], disable_fallback: bool = False + ): try: # publish all comments in a single message self.pr.create_review(commit=self.last_commit_id, comments=comments) except Exception as e: - get_logger().info(f"Initially failed to publish inline comments as committable") + get_logger().info( + f"Initially failed to publish inline comments as committable" + ) - if (getattr(e, "status", None) == 422 and not disable_fallback): + if getattr(e, 
"status", None) == 422 and not disable_fallback: pass # continue to try _publish_inline_comments_fallback_with_verification else: - raise e # will end up with publishing the comments one by one + raise e # will end up with publishing the comments one by one try: self._publish_inline_comments_fallback_with_verification(comments) except Exception as e: - get_logger().error(f"Failed to publish inline code comments fallback, error: {e}") + get_logger().error( + f"Failed to publish inline code comments fallback, error: {e}" + ) raise e def _publish_inline_comments_fallback_with_verification(self, comments: list[dict]): @@ -352,20 +451,27 @@ class GithubProvider(GitProvider): # publish as a group the verified comments if verified_comments: try: - self.pr.create_review(commit=self.last_commit_id, comments=verified_comments) + self.pr.create_review( + commit=self.last_commit_id, comments=verified_comments + ) except: pass # try to publish one by one the invalid comments as a one-line code comment if invalid_comments and get_settings().github.try_fix_invalid_inline_comments: fixed_comments_as_one_liner = self._try_fix_invalid_inline_comments( - [comment for comment, _ in invalid_comments]) + [comment for comment, _ in invalid_comments] + ) for comment in fixed_comments_as_one_liner: try: self.publish_inline_comments([comment], disable_fallback=True) - get_logger().info(f"Published invalid comment as a single line comment: {comment}") + get_logger().info( + f"Published invalid comment as a single line comment: {comment}" + ) except: - get_logger().error(f"Failed to publish invalid comment as a single line comment: {comment}") + get_logger().error( + f"Failed to publish invalid comment as a single line comment: {comment}" + ) def _verify_code_comment(self, comment: dict): is_verified = False @@ -374,7 +480,8 @@ class GithubProvider(GitProvider): # event ="" # By leaving this blank, you set the review action state to PENDING input = dict(commit_id=self.last_commit_id.sha, comments=[comment]) headers, data = self.pr._requester.requestJsonAndCheck( - "POST", f"{self.pr.url}/reviews", input=input) + "POST", f"{self.pr.url}/reviews", input=input + ) pending_review_id = data["id"] is_verified = True except Exception as err: @@ -383,12 +490,16 @@ class GithubProvider(GitProvider): e = err if pending_review_id is not None: try: - self.pr._requester.requestJsonAndCheck("DELETE", f"{self.pr.url}/reviews/{pending_review_id}") + self.pr._requester.requestJsonAndCheck( + "DELETE", f"{self.pr.url}/reviews/{pending_review_id}" + ) except Exception: pass return is_verified, e - def _verify_code_comments(self, comments: list[dict]) -> tuple[list[dict], list[tuple[dict, Exception]]]: + def _verify_code_comments( + self, comments: list[dict] + ) -> tuple[list[dict], list[tuple[dict, Exception]]]: """Very each comment against the GitHub API and return 2 lists: 1 of verified and 1 of invalid comments""" verified_comments = [] invalid_comments = [] @@ -401,17 +512,22 @@ class GithubProvider(GitProvider): invalid_comments.append((comment, e)) return verified_comments, invalid_comments - def _try_fix_invalid_inline_comments(self, invalid_comments: list[dict]) -> list[dict]: + def _try_fix_invalid_inline_comments( + self, invalid_comments: list[dict] + ) -> list[dict]: """ Try fixing invalid comments by removing the suggestion part and setting the comment just on the first line. Return only comments that have been modified in some way. This is a best-effort attempt to fix invalid comments, and should be verified accordingly. 
""" import copy + fixed_comments = [] for comment in invalid_comments: try: - fixed_comment = copy.deepcopy(comment) # avoid modifying the original comment dict for later logging + fixed_comment = copy.deepcopy( + comment + ) # avoid modifying the original comment dict for later logging if "```suggestion" in comment["body"]: fixed_comment["body"] = comment["body"].split("```suggestion")[0] if "start_line" in comment: @@ -432,7 +548,9 @@ class GithubProvider(GitProvider): """ post_parameters_list = [] - code_suggestions_validated = self.validate_comments_inside_hunks(code_suggestions) + code_suggestions_validated = self.validate_comments_inside_hunks( + code_suggestions + ) for suggestion in code_suggestions_validated: body = suggestion['body'] @@ -442,13 +560,16 @@ class GithubProvider(GitProvider): if not relevant_lines_start or relevant_lines_start == -1: get_logger().exception( - f"Failed to publish code suggestion, relevant_lines_start is {relevant_lines_start}") + f"Failed to publish code suggestion, relevant_lines_start is {relevant_lines_start}" + ) continue if relevant_lines_end < relevant_lines_start: - get_logger().exception(f"Failed to publish code suggestion, " - f"relevant_lines_end is {relevant_lines_end} and " - f"relevant_lines_start is {relevant_lines_start}") + get_logger().exception( + f"Failed to publish code suggestion, " + f"relevant_lines_end is {relevant_lines_end} and " + f"relevant_lines_start is {relevant_lines_start}" + ) continue if relevant_lines_end > relevant_lines_start: @@ -484,17 +605,21 @@ class GithubProvider(GitProvider): # Log as warning for permission-related issues (usually due to polling) get_logger().warning( "Failed to edit github comment due to permission restrictions", - artifact={"error": e}) + artifact={"error": e}, + ) else: - get_logger().exception(f"Failed to edit github comment", artifact={"error": e}) + get_logger().exception( + f"Failed to edit github comment", artifact={"error": e} + ) def edit_comment_from_comment_id(self, comment_id: int, body: str): try: # self.pr.get_issue_comment(comment_id).edit(body) body = self.limit_output_characters(body, self.max_comment_chars) headers, data_patch = self.pr._requester.requestJsonAndCheck( - "PATCH", f"{self.base_url}/repos/{self.repo}/issues/comments/{comment_id}", - input={"body": body} + "PATCH", + f"{self.base_url}/repos/{self.repo}/issues/comments/{comment_id}", + input={"body": body}, ) except Exception as e: get_logger().exception(f"Failed to edit comment, error: {e}") @@ -504,8 +629,9 @@ class GithubProvider(GitProvider): # self.pr.get_issue_comment(comment_id).edit(body) body = self.limit_output_characters(body, self.max_comment_chars) headers, data_patch = self.pr._requester.requestJsonAndCheck( - "POST", f"{self.base_url}/repos/{self.repo}/pulls/{self.pr_num}/comments/{comment_id}/replies", - input={"body": body} + "POST", + f"{self.base_url}/repos/{self.repo}/pulls/{self.pr_num}/comments/{comment_id}/replies", + input={"body": body}, ) except Exception as e: get_logger().exception(f"Failed to reply comment, error: {e}") @@ -516,7 +642,7 @@ class GithubProvider(GitProvider): headers, data_patch = self.pr._requester.requestJsonAndCheck( "GET", f"{self.base_url}/repos/{self.repo}/issues/comments/{comment_id}" ) - return data_patch.get("body","") + return data_patch.get("body", "") except Exception as e: get_logger().exception(f"Failed to edit comment, error: {e}") return None @@ -528,7 +654,9 @@ class GithubProvider(GitProvider): ) for comment in file_comments: comment['commit_id'] = 
self.last_commit_id.sha - comment['body'] = self.limit_output_characters(comment['body'], self.max_comment_chars) + comment['body'] = self.limit_output_characters( + comment['body'], self.max_comment_chars + ) found = False for existing_comment in existing_comments: @@ -536,13 +664,23 @@ class GithubProvider(GitProvider): our_app_name = get_settings().get("GITHUB.APP_NAME", "") same_comment_creator = False if self.deployment_type == 'app': - same_comment_creator = our_app_name.lower() in existing_comment['user']['login'].lower() + same_comment_creator = ( + our_app_name.lower() + in existing_comment['user']['login'].lower() + ) elif self.deployment_type == 'user': - same_comment_creator = self.github_user_id == existing_comment['user']['login'] - if existing_comment['subject_type'] == 'file' and comment['path'] == existing_comment['path'] and same_comment_creator: - + same_comment_creator = ( + self.github_user_id == existing_comment['user']['login'] + ) + if ( + existing_comment['subject_type'] == 'file' + and comment['path'] == existing_comment['path'] + and same_comment_creator + ): headers, data_patch = self.pr._requester.requestJsonAndCheck( - "PATCH", f"{self.base_url}/repos/{self.repo}/pulls/comments/{existing_comment['id']}", input={"body":comment['body']} + "PATCH", + f"{self.base_url}/repos/{self.repo}/pulls/comments/{existing_comment['id']}", + input={"body": comment['body']}, ) found = True break @@ -600,7 +738,9 @@ class GithubProvider(GitProvider): deployment_type = get_settings().get("GITHUB.DEPLOYMENT_TYPE", "user") if deployment_type != 'user': - raise ValueError("Deployment mode must be set to 'user' to get notifications") + raise ValueError( + "Deployment mode must be set to 'user' to get notifications" + ) notifications = self.github_client.get_user().get_notifications(since=since) return notifications @@ -621,13 +761,16 @@ class GithubProvider(GitProvider): def get_workspace_name(self): return self.repo.split('/')[0] - def add_eyes_reaction(self, issue_comment_id: int, disable_eyes: bool = False) -> Optional[int]: + def add_eyes_reaction( + self, issue_comment_id: int, disable_eyes: bool = False + ) -> Optional[int]: if disable_eyes: return None try: headers, data_patch = self.pr._requester.requestJsonAndCheck( - "POST", f"{self.base_url}/repos/{self.repo}/issues/comments/{issue_comment_id}/reactions", - input={"content": "eyes"} + "POST", + f"{self.base_url}/repos/{self.repo}/issues/comments/{issue_comment_id}/reactions", + input={"content": "eyes"}, ) return data_patch.get("id", None) except Exception as e: @@ -639,7 +782,7 @@ class GithubProvider(GitProvider): # self.pr.get_issue_comment(issue_comment_id).delete_reaction(reaction_id) headers, data_patch = self.pr._requester.requestJsonAndCheck( "DELETE", - f"{self.base_url}/repos/{self.repo}/issues/comments/{issue_comment_id}/reactions/{reaction_id}" + f"{self.base_url}/repos/{self.repo}/issues/comments/{issue_comment_id}/reactions/{reaction_id}", ) return True except Exception as e: @@ -655,7 +798,9 @@ class GithubProvider(GitProvider): path_parts = parsed_url.path.strip('/').split('/') if 'api.github.com' in parsed_url.netloc or '/api/v3' in pr_url: if len(path_parts) < 5 or path_parts[3] != 'pulls': - raise ValueError("The provided URL does not appear to be a GitHub PR URL") + raise ValueError( + "The provided URL does not appear to be a GitHub PR URL" + ) repo_name = '/'.join(path_parts[1:3]) try: pr_number = int(path_parts[4]) @@ -683,7 +828,9 @@ class GithubProvider(GitProvider): path_parts = 
parsed_url.path.strip('/').split('/') if 'api.github.com' in parsed_url.netloc: if len(path_parts) < 5 or path_parts[3] != 'issues': - raise ValueError("The provided URL does not appear to be a GitHub ISSUE URL") + raise ValueError( + "The provided URL does not appear to be a GitHub ISSUE URL" + ) repo_name = '/'.join(path_parts[1:3]) try: issue_number = int(path_parts[4]) @@ -710,11 +857,18 @@ class GithubProvider(GitProvider): private_key = get_settings().github.private_key app_id = get_settings().github.app_id except AttributeError as e: - raise ValueError("GitHub app ID and private key are required when using GitHub app deployment") from e + raise ValueError( + "GitHub app ID and private key are required when using GitHub app deployment" + ) from e if not self.installation_id: - raise ValueError("GitHub app installation ID is required when using GitHub app deployment") - auth = AppAuthentication(app_id=app_id, private_key=private_key, - installation_id=self.installation_id) + raise ValueError( + "GitHub app installation ID is required when using GitHub app deployment" + ) + auth = AppAuthentication( + app_id=app_id, + private_key=private_key, + installation_id=self.installation_id, + ) return Github(app_auth=auth, base_url=self.base_url) if deployment_type == 'user': @@ -723,19 +877,21 @@ class GithubProvider(GitProvider): except AttributeError as e: raise ValueError( "GitHub token is required when using user deployment. See: " - "https://github.com/Codium-ai/pr-agent#method-2-run-from-source") from e + "https://github.com/Codium-ai/pr-agent#method-2-run-from-source" + ) from e return Github(auth=Auth.Token(token), base_url=self.base_url) def _get_repo(self): - if hasattr(self, 'repo_obj') and \ - hasattr(self.repo_obj, 'full_name') and \ - self.repo_obj.full_name == self.repo: + if ( + hasattr(self, 'repo_obj') + and hasattr(self.repo_obj, 'full_name') + and self.repo_obj.full_name == self.repo + ): return self.repo_obj else: self.repo_obj = self.github_client.get_repo(self.repo) return self.repo_obj - def _get_pr(self): return self._get_repo().get_pull(self.pr_num) @@ -755,9 +911,9 @@ class GithubProvider(GitProvider): ) -> None: try: file_obj = self._get_repo().get_contents(file_path, ref=branch) - sha1=file_obj.sha + sha1 = file_obj.sha except Exception: - sha1="" + sha1 = "" self.repo_obj.update_file( path=file_path, message=message, @@ -771,9 +927,14 @@ class GithubProvider(GitProvider): def publish_labels(self, pr_types): try: - label_color_map = {"Bug fix": "1d76db", "Tests": "e99695", "Bug fix with tests": "c5def5", - "Enhancement": "bfd4f2", "Documentation": "d4c5f9", - "Other": "d1bcf9"} + label_color_map = { + "Bug fix": "1d76db", + "Tests": "e99695", + "Bug fix with tests": "c5def5", + "Enhancement": "bfd4f2", + "Documentation": "d4c5f9", + "Other": "d1bcf9", + } post_parameters = [] for p in pr_types: color = label_color_map.get(p, "d1bcf9") # default to "Other" color @@ -787,11 +948,12 @@ class GithubProvider(GitProvider): def get_pr_labels(self, update=False): try: if not update: - labels =self.pr.labels + labels = self.pr.labels return [label.name for label in labels] - else: # obtain the latest labels. Maybe they changed while the AI was running + else: # obtain the latest labels. 
Maybe they changed while the AI was running headers, labels = self.pr._requester.requestJsonAndCheck( - "GET", f"{self.pr.issue_url}/labels") + "GET", f"{self.pr.issue_url}/labels" + ) return [label['name'] for label in labels] except Exception as e: @@ -813,7 +975,9 @@ class GithubProvider(GitProvider): try: commit_list = self.pr.get_commits() commit_messages = [commit.commit.message for commit in commit_list] - commit_messages_str = "\n".join([f"{i + 1}. {message}" for i, message in enumerate(commit_messages)]) + commit_messages_str = "\n".join( + [f"{i + 1}. {message}" for i, message in enumerate(commit_messages)] + ) except Exception: commit_messages_str = "" if max_tokens: @@ -822,13 +986,16 @@ class GithubProvider(GitProvider): def generate_link_to_relevant_line_number(self, suggestion) -> str: try: - relevant_file = suggestion['relevant_file'].strip('`').strip("'").strip('\n') + relevant_file = ( + suggestion['relevant_file'].strip('`').strip("'").strip('\n') + ) relevant_line_str = suggestion['relevant_line'].strip('\n') if not relevant_line_str: return "" - position, absolute_position = find_line_number_of_relevant_line_in_file \ - (self.diff_files, relevant_file, relevant_line_str) + position, absolute_position = find_line_number_of_relevant_line_in_file( + self.diff_files, relevant_file, relevant_line_str + ) if absolute_position != -1: # # link to right file only @@ -844,7 +1011,12 @@ class GithubProvider(GitProvider): return "" - def get_line_link(self, relevant_file: str, relevant_line_start: int, relevant_line_end: int = None) -> str: + def get_line_link( + self, + relevant_file: str, + relevant_line_start: int, + relevant_line_end: int = None, + ) -> str: sha_file = hashlib.sha256(relevant_file.encode('utf-8')).hexdigest() if relevant_line_start == -1: link = f"{self.base_url_html}/{self.repo}/pull/{self.pr_num}/files#diff-{sha_file}" @@ -854,7 +1026,9 @@ class GithubProvider(GitProvider): link = f"{self.base_url_html}/{self.repo}/pull/{self.pr_num}/files#diff-{sha_file}R{relevant_line_start}" return link - def get_lines_link_original_file(self, filepath: str, component_range: Range) -> str: + def get_lines_link_original_file( + self, filepath: str, component_range: Range + ) -> str: """ Returns the link to the original file on GitHub that corresponds to the given filepath and component range. 
@@ -876,8 +1050,10 @@ class GithubProvider(GitProvider): line_end = component_range.line_end + 1 # link = (f"https://github.com/{self.repo}/blob/{self.last_commit_id.sha}/{filepath}/" # f"#L{line_start}-L{line_end}") - link = (f"{self.base_url_html}/{self.repo}/blob/{self.last_commit_id.sha}/{filepath}/" - f"#L{line_start}-L{line_end}") + link = ( + f"{self.base_url_html}/{self.repo}/blob/{self.last_commit_id.sha}/{filepath}/" + f"#L{line_start}-L{line_end}" + ) return link @@ -909,8 +1085,9 @@ class GithubProvider(GitProvider): }} }} """ - response_tuple = self.github_client._Github__requester.requestJson("POST", "/graphql", - input={"query": query}) + response_tuple = self.github_client._Github__requester.requestJson( + "POST", "/graphql", input={"query": query} + ) # Extract the JSON response from the tuple and parses it if isinstance(response_tuple, tuple) and len(response_tuple) == 3: @@ -919,8 +1096,12 @@ class GithubProvider(GitProvider): get_logger().error(f"Unexpected response format: {response_tuple}") return sub_issues - - issue_id = response_json.get("data", {}).get("repository", {}).get("issue", {}).get("id") + issue_id = ( + response_json.get("data", {}) + .get("repository", {}) + .get("issue", {}) + .get("id") + ) if not issue_id: get_logger().warning(f"Issue ID not found for {issue_url}") @@ -940,22 +1121,42 @@ class GithubProvider(GitProvider): }} }} """ - sub_issues_response_tuple = self.github_client._Github__requester.requestJson("POST", "/graphql", input={ - "query": sub_issues_query}) + sub_issues_response_tuple = ( + self.github_client._Github__requester.requestJson( + "POST", "/graphql", input={"query": sub_issues_query} + ) + ) # Extract the JSON response from the tuple and parses it - if isinstance(sub_issues_response_tuple, tuple) and len(sub_issues_response_tuple) == 3: + if ( + isinstance(sub_issues_response_tuple, tuple) + and len(sub_issues_response_tuple) == 3 + ): sub_issues_response_json = json.loads(sub_issues_response_tuple[2]) else: - get_logger().error("Unexpected sub-issues response format", artifact={"response": sub_issues_response_tuple}) + get_logger().error( + "Unexpected sub-issues response format", + artifact={"response": sub_issues_response_tuple}, + ) return sub_issues - if not sub_issues_response_json.get("data", {}).get("node", {}).get("subIssues"): + if ( + not sub_issues_response_json.get("data", {}) + .get("node", {}) + .get("subIssues") + ): get_logger().error("Invalid sub-issues response structure") return sub_issues - - nodes = sub_issues_response_json.get("data", {}).get("node", {}).get("subIssues", {}).get("nodes", []) - get_logger().info(f"Github Sub-issues fetched: {len(nodes)}", artifact={"nodes": nodes}) + + nodes = ( + sub_issues_response_json.get("data", {}) + .get("node", {}) + .get("subIssues", {}) + .get("nodes", []) + ) + get_logger().info( + f"Github Sub-issues fetched: {len(nodes)}", artifact={"nodes": nodes} + ) for sub_issue in nodes: if "url" in sub_issue: @@ -977,7 +1178,7 @@ class GithubProvider(GitProvider): return False def calc_pr_statistics(self, pull_request_data: dict): - return {} + return {} def validate_comments_inside_hunks(self, code_suggestions): """ @@ -986,7 +1187,8 @@ class GithubProvider(GitProvider): code_suggestions_copy = copy.deepcopy(code_suggestions) diff_files = self.get_diff_files() RE_HUNK_HEADER = re.compile( - r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)") + r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? 
@@[ ]?(.*)" + ) diff_files = set_file_languages(diff_files) @@ -995,7 +1197,6 @@ class GithubProvider(GitProvider): relevant_file_path = suggestion['relevant_file'] for file in diff_files: if file.filename == relevant_file_path: - # generate on-demand the patches range for the relevant file patch_str = file.patch if not hasattr(file, 'patches_range'): @@ -1006,14 +1207,30 @@ class GithubProvider(GitProvider): match = RE_HUNK_HEADER.match(line) # identify hunk header if match: - section_header, size1, size2, start1, start2 = extract_hunk_headers(match) - file.patches_range.append({'start': start2, 'end': start2 + size2 - 1}) + ( + section_header, + size1, + size2, + start1, + start2, + ) = extract_hunk_headers(match) + file.patches_range.append( + {'start': start2, 'end': start2 + size2 - 1} + ) patches_range = file.patches_range - comment_start_line = suggestion.get('relevant_lines_start', None) + comment_start_line = suggestion.get( + 'relevant_lines_start', None + ) comment_end_line = suggestion.get('relevant_lines_end', None) - original_suggestion = suggestion.get('original_suggestion', None) # needed for diff code - if not comment_start_line or not comment_end_line or not original_suggestion: + original_suggestion = suggestion.get( + 'original_suggestion', None + ) # needed for diff code + if ( + not comment_start_line + or not comment_end_line + or not original_suggestion + ): continue # check if the comment is inside a valid hunk @@ -1037,30 +1254,57 @@ class GithubProvider(GitProvider): patch_range_min = patch_range min_distance = min(min_distance, d) if not is_valid_hunk: - if min_distance < 10: # 10 lines - a reasonable distance to consider the comment inside the hunk + if ( + min_distance < 10 + ): # 10 lines - a reasonable distance to consider the comment inside the hunk # make the suggestion non-committable, yet multi line - suggestion['relevant_lines_start'] = max(suggestion['relevant_lines_start'], patch_range_min['start']) - suggestion['relevant_lines_end'] = min(suggestion['relevant_lines_end'], patch_range_min['end']) + suggestion['relevant_lines_start'] = max( + suggestion['relevant_lines_start'], + patch_range_min['start'], + ) + suggestion['relevant_lines_end'] = min( + suggestion['relevant_lines_end'], + patch_range_min['end'], + ) body = suggestion['body'].strip() # present new diff code in collapsible - existing_code = original_suggestion['existing_code'].rstrip() + "\n" - improved_code = original_suggestion['improved_code'].rstrip() + "\n" - diff = difflib.unified_diff(existing_code.split('\n'), - improved_code.split('\n'), n=999) + existing_code = ( + original_suggestion['existing_code'].rstrip() + "\n" + ) + improved_code = ( + original_suggestion['improved_code'].rstrip() + "\n" + ) + diff = difflib.unified_diff( + existing_code.split('\n'), + improved_code.split('\n'), + n=999, + ) patch_orig = "\n".join(diff) - patch = "\n".join(patch_orig.splitlines()[5:]).strip('\n') + patch = "\n".join(patch_orig.splitlines()[5:]).strip( + '\n' + ) diff_code = f"\n\n
<details><summary>新提议的代码:</summary>\n\n```diff\n{patch.rstrip()}\n```" # replace ```suggestion ... ``` with diff_code, using regex: - body = re.sub(r'```suggestion.*?```', diff_code, body, flags=re.DOTALL) + body = re.sub( + r'```suggestion.*?```', + diff_code, + body, + flags=re.DOTALL, + ) body += "\n\n</details>
" suggestion['body'] = body - get_logger().info(f"Comment was moved to a valid hunk, " - f"start_line={suggestion['relevant_lines_start']}, end_line={suggestion['relevant_lines_end']}, file={file.filename}") + get_logger().info( + f"Comment was moved to a valid hunk, " + f"start_line={suggestion['relevant_lines_start']}, end_line={suggestion['relevant_lines_end']}, file={file.filename}" + ) else: - get_logger().error(f"Comment is not inside a valid hunk, " - f"start_line={suggestion['relevant_lines_start']}, end_line={suggestion['relevant_lines_end']}, file={file.filename}") + get_logger().error( + f"Comment is not inside a valid hunk, " + f"start_line={suggestion['relevant_lines_start']}, end_line={suggestion['relevant_lines_end']}, file={file.filename}" + ) except Exception as e: - get_logger().error(f"Failed to process patch for committable comment, error: {e}") + get_logger().error( + f"Failed to process patch for committable comment, error: {e}" + ) return code_suggestions_copy - diff --git a/apps/utils/pr_agent/git_providers/gitlab_provider.py b/apps/utils/pr_agent/git_providers/gitlab_provider.py index 3d630c5..0217846 100644 --- a/apps/utils/pr_agent/git_providers/gitlab_provider.py +++ b/apps/utils/pr_agent/git_providers/gitlab_provider.py @@ -10,9 +10,11 @@ from utils.pr_agent.algo.types import EDIT_TYPE, FilePatchInfo from ..algo.file_filter import filter_ignored from ..algo.language_handler import is_valid_file -from ..algo.utils import (clip_tokens, - find_line_number_of_relevant_line_in_file, - load_large_diff) +from ..algo.utils import ( + clip_tokens, + find_line_number_of_relevant_line_in_file, + load_large_diff, +) from ..config_loader import get_settings from ..log import get_logger from .git_provider import MAX_FILES_ALLOWED_FULL, GitProvider @@ -20,22 +22,26 @@ from .git_provider import MAX_FILES_ALLOWED_FULL, GitProvider class DiffNotFoundError(Exception): """Raised when the diff for a merge request cannot be found.""" + pass -class GitLabProvider(GitProvider): - def __init__(self, merge_request_url: Optional[str] = None, incremental: Optional[bool] = False): +class GitLabProvider(GitProvider): + def __init__( + self, + merge_request_url: Optional[str] = None, + incremental: Optional[bool] = False, + ): gitlab_url = get_settings().get("GITLAB.URL", None) if not gitlab_url: raise ValueError("GitLab URL is not set in the config file") self.gitlab_url = gitlab_url gitlab_access_token = get_settings().get("GITLAB.PERSONAL_ACCESS_TOKEN", None) if not gitlab_access_token: - raise ValueError("GitLab personal access token is not set in the config file") - self.gl = gitlab.Gitlab( - url=gitlab_url, - oauth_token=gitlab_access_token - ) + raise ValueError( + "GitLab personal access token is not set in the config file" + ) + self.gl = gitlab.Gitlab(url=gitlab_url, oauth_token=gitlab_access_token) self.max_comment_chars = 65000 self.id_project = None self.id_mr = None @@ -46,12 +52,17 @@ class GitLabProvider(GitProvider): self.pr_url = merge_request_url self._set_merge_request(merge_request_url) self.RE_HUNK_HEADER = re.compile( - r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)") + r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)" + ) self.incremental = incremental def is_supported(self, capability: str) -> bool: - if capability in ['get_issue_comments', 'create_inline_comment', 'publish_inline_comments', - 'publish_file_comments']: # gfm_markdown is supported in gitlab ! 
+ if capability in [ + 'get_issue_comments', + 'create_inline_comment', + 'publish_inline_comments', + 'publish_file_comments', + ]: # gfm_markdown is supported in gitlab ! return False return True @@ -67,12 +78,17 @@ class GitLabProvider(GitProvider): self.last_diff = self.mr.diffs.list(get_all=True)[-1] except IndexError as e: get_logger().error(f"Could not get diff for merge request {self.id_mr}") - raise DiffNotFoundError(f"Could not get diff for merge request {self.id_mr}") from e - + raise DiffNotFoundError( + f"Could not get diff for merge request {self.id_mr}" + ) from e def get_pr_file_content(self, file_path: str, branch: str) -> str: try: - return self.gl.projects.get(self.id_project).files.get(file_path, branch).decode() + return ( + self.gl.projects.get(self.id_project) + .files.get(file_path, branch) + .decode() + ) except GitlabGetError: # In case of file creation the method returns GitlabGetError (404 file not found). # In this case we return an empty string for the diff. @@ -98,10 +114,13 @@ class GitLabProvider(GitProvider): try: names_original = [diff['new_path'] for diff in diffs_original] names_filtered = [diff['new_path'] for diff in diffs] - get_logger().info(f"Filtered out [ignore] files for merge request {self.id_mr}", extra={ - 'original_files': names_original, - 'filtered_files': names_filtered - }) + get_logger().info( + f"Filtered out [ignore] files for merge request {self.id_mr}", + extra={ + 'original_files': names_original, + 'filtered_files': names_filtered, + }, + ) except Exception as e: pass @@ -116,22 +135,31 @@ class GitLabProvider(GitProvider): # allow only a limited number of files to be fully loaded. We can manage the rest with diffs only counter_valid += 1 if counter_valid < MAX_FILES_ALLOWED_FULL or not diff['diff']: - original_file_content_str = self.get_pr_file_content(diff['old_path'], self.mr.diff_refs['base_sha']) - new_file_content_str = self.get_pr_file_content(diff['new_path'], self.mr.diff_refs['head_sha']) + original_file_content_str = self.get_pr_file_content( + diff['old_path'], self.mr.diff_refs['base_sha'] + ) + new_file_content_str = self.get_pr_file_content( + diff['new_path'], self.mr.diff_refs['head_sha'] + ) else: if counter_valid == MAX_FILES_ALLOWED_FULL: - get_logger().info(f"Too many files in PR, will avoid loading full content for rest of files") + get_logger().info( + f"Too many files in PR, will avoid loading full content for rest of files" + ) original_file_content_str = '' new_file_content_str = '' try: if isinstance(original_file_content_str, bytes): - original_file_content_str = bytes.decode(original_file_content_str, 'utf-8') + original_file_content_str = bytes.decode( + original_file_content_str, 'utf-8' + ) if isinstance(new_file_content_str, bytes): new_file_content_str = bytes.decode(new_file_content_str, 'utf-8') except UnicodeDecodeError: get_logger().warning( - f"Cannot decode file {diff['old_path']} or {diff['new_path']} in merge request {self.id_mr}") + f"Cannot decode file {diff['old_path']} or {diff['new_path']} in merge request {self.id_mr}" + ) edit_type = EDIT_TYPE.MODIFIED if diff['new_file']: @@ -144,30 +172,43 @@ class GitLabProvider(GitProvider): filename = diff['new_path'] patch = diff['diff'] if not patch: - patch = load_large_diff(filename, new_file_content_str, original_file_content_str) - + patch = load_large_diff( + filename, new_file_content_str, original_file_content_str + ) # count number of lines added and removed patch_lines = patch.splitlines(keepends=True) num_plus_lines = len([line for 
line in patch_lines if line.startswith('+')]) - num_minus_lines = len([line for line in patch_lines if line.startswith('-')]) + num_minus_lines = len( + [line for line in patch_lines if line.startswith('-')] + ) diff_files.append( - FilePatchInfo(original_file_content_str, new_file_content_str, - patch=patch, - filename=filename, - edit_type=edit_type, - old_filename=None if diff['old_path'] == diff['new_path'] else diff['old_path'], - num_plus_lines=num_plus_lines, - num_minus_lines=num_minus_lines, )) + FilePatchInfo( + original_file_content_str, + new_file_content_str, + patch=patch, + filename=filename, + edit_type=edit_type, + old_filename=None + if diff['old_path'] == diff['new_path'] + else diff['old_path'], + num_plus_lines=num_plus_lines, + num_minus_lines=num_minus_lines, + ) + ) if invalid_files_names: - get_logger().info(f"Filtered out files with invalid extensions: {invalid_files_names}") + get_logger().info( + f"Filtered out files with invalid extensions: {invalid_files_names}" + ) self.diff_files = diff_files return diff_files def get_files(self) -> list: if not self.git_files: - self.git_files = [change['new_path'] for change in self.mr.changes()['changes']] + self.git_files = [ + change['new_path'] for change in self.mr.changes()['changes'] + ] return self.git_files def publish_description(self, pr_title: str, pr_body: str): @@ -176,7 +217,9 @@ class GitLabProvider(GitProvider): self.mr.description = pr_body self.mr.save() except Exception as e: - get_logger().exception(f"Could not update merge request {self.id_mr} description: {e}") + get_logger().exception( + f"Could not update merge request {self.id_mr} description: {e}" + ) def get_latest_commit_url(self): return self.mr.commits().next().web_url @@ -184,16 +227,23 @@ class GitLabProvider(GitProvider): def get_comment_url(self, comment): return f"{self.mr.web_url}#note_{comment.id}" - def publish_persistent_comment(self, pr_comment: str, - initial_header: str, - update_header: bool = True, - name='review', - final_update_message=True): - self.publish_persistent_comment_full(pr_comment, initial_header, update_header, name, final_update_message) + def publish_persistent_comment( + self, + pr_comment: str, + initial_header: str, + update_header: bool = True, + name='review', + final_update_message=True, + ): + self.publish_persistent_comment_full( + pr_comment, initial_header, update_header, name, final_update_message + ) def publish_comment(self, mr_comment: str, is_temporary: bool = False): if is_temporary and not get_settings().config.publish_output_progress: - get_logger().debug(f"Skipping publish_comment for temporary comment: {mr_comment}") + get_logger().debug( + f"Skipping publish_comment for temporary comment: {mr_comment}" + ) return None mr_comment = self.limit_output_characters(mr_comment, self.max_comment_chars) comment = self.mr.notes.create({'body': mr_comment}) @@ -203,7 +253,7 @@ class GitLabProvider(GitProvider): def edit_comment(self, comment, body: str): body = self.limit_output_characters(body, self.max_comment_chars) - self.mr.notes.update(comment.id,{'body': body} ) + self.mr.notes.update(comment.id, {'body': body}) def edit_comment_from_comment_id(self, comment_id: int, body: str): body = self.limit_output_characters(body, self.max_comment_chars) @@ -216,39 +266,87 @@ class GitLabProvider(GitProvider): discussion = self.mr.discussions.get(comment_id) discussion.notes.create({'body': body}) - def publish_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str, 
original_suggestion=None): + def publish_inline_comment( + self, + body: str, + relevant_file: str, + relevant_line_in_file: str, + original_suggestion=None, + ): body = self.limit_output_characters(body, self.max_comment_chars) - edit_type, found, source_line_no, target_file, target_line_no = self.search_line(relevant_file, - relevant_line_in_file) - self.send_inline_comment(body, edit_type, found, relevant_file, relevant_line_in_file, source_line_no, - target_file, target_line_no, original_suggestion) + ( + edit_type, + found, + source_line_no, + target_file, + target_line_no, + ) = self.search_line(relevant_file, relevant_line_in_file) + self.send_inline_comment( + body, + edit_type, + found, + relevant_file, + relevant_line_in_file, + source_line_no, + target_file, + target_line_no, + original_suggestion, + ) - def create_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str, absolute_position: int = None): - raise NotImplementedError("Gitlab provider does not support creating inline comments yet") + def create_inline_comment( + self, + body: str, + relevant_file: str, + relevant_line_in_file: str, + absolute_position: int = None, + ): + raise NotImplementedError( + "Gitlab provider does not support creating inline comments yet" + ) def create_inline_comments(self, comments: list[dict]): - raise NotImplementedError("Gitlab provider does not support publishing inline comments yet") + raise NotImplementedError( + "Gitlab provider does not support publishing inline comments yet" + ) def get_comment_body_from_comment_id(self, comment_id: int): comment = self.mr.notes.get(comment_id).body return comment - def send_inline_comment(self, body: str, edit_type: str, found: bool, relevant_file: str, - relevant_line_in_file: str, - source_line_no: int, target_file: str, target_line_no: int, - original_suggestion=None) -> None: + def send_inline_comment( + self, + body: str, + edit_type: str, + found: bool, + relevant_file: str, + relevant_line_in_file: str, + source_line_no: int, + target_file: str, + target_line_no: int, + original_suggestion=None, + ) -> None: if not found: - get_logger().info(f"Could not find position for {relevant_file} {relevant_line_in_file}") + get_logger().info( + f"Could not find position for {relevant_file} {relevant_line_in_file}" + ) else: # in order to have exact sha's we have to find correct diff for this change diff = self.get_relevant_diff(relevant_file, relevant_line_in_file) if diff is None: get_logger().error(f"Could not get diff for merge request {self.id_mr}") - raise DiffNotFoundError(f"Could not get diff for merge request {self.id_mr}") - pos_obj = {'position_type': 'text', - 'new_path': target_file.filename, - 'old_path': target_file.old_filename if target_file.old_filename else target_file.filename, - 'base_sha': diff.base_commit_sha, 'start_sha': diff.start_commit_sha, 'head_sha': diff.head_commit_sha} + raise DiffNotFoundError( + f"Could not get diff for merge request {self.id_mr}" + ) + pos_obj = { + 'position_type': 'text', + 'new_path': target_file.filename, + 'old_path': target_file.old_filename + if target_file.old_filename + else target_file.filename, + 'base_sha': diff.base_commit_sha, + 'start_sha': diff.start_commit_sha, + 'head_sha': diff.head_commit_sha, + } if edit_type == 'deletion': pos_obj['old_line'] = source_line_no - 1 elif edit_type == 'addition': @@ -256,15 +354,21 @@ class GitLabProvider(GitProvider): else: pos_obj['new_line'] = target_line_no - 1 pos_obj['old_line'] = source_line_no - 1 - 
get_logger().debug(f"Creating comment in MR {self.id_mr} with body {body} and position {pos_obj}") + get_logger().debug( + f"Creating comment in MR {self.id_mr} with body {body} and position {pos_obj}" + ) try: self.mr.discussions.create({'body': body, 'position': pos_obj}) except Exception as e: try: # fallback - create a general note on the file in the MR if 'suggestion_orig_location' in original_suggestion: - line_start = original_suggestion['suggestion_orig_location']['start_line'] - line_end = original_suggestion['suggestion_orig_location']['end_line'] + line_start = original_suggestion['suggestion_orig_location'][ + 'start_line' + ] + line_end = original_suggestion['suggestion_orig_location'][ + 'end_line' + ] old_code_snippet = original_suggestion['prev_code_snippet'] new_code_snippet = original_suggestion['new_code_snippet'] content = original_suggestion['suggestion_summary'] @@ -287,36 +391,49 @@ class GitLabProvider(GitProvider): else: language = '' link = self.get_line_link(relevant_file, line_start, line_end) - body_fallback =f"**Suggestion:** {content} [{label}, importance: {score}]\n\n" - body_fallback +=f"\n\n
<details><summary>[{target_file.filename} [{line_start}-{line_end}]]({link}):</summary>\n\n" + body_fallback = ( + f"**Suggestion:** {content} [{label}, importance: {score}]\n\n" + ) + body_fallback += f"\n\n<details><summary>
[{target_file.filename} [{line_start}-{line_end}]]({link}):</summary>\n\n" body_fallback += f"\n\n___\n\n`(Cannot implement directly - GitLab API allows committable suggestions strictly on MR diff lines)`" - body_fallback+="
\n\n" - diff_patch = difflib.unified_diff(old_code_snippet.split('\n'), - new_code_snippet.split('\n'), n=999) + body_fallback += "
\n\n" + diff_patch = difflib.unified_diff( + old_code_snippet.split('\n'), + new_code_snippet.split('\n'), + n=999, + ) patch_orig = "\n".join(diff_patch) patch = "\n".join(patch_orig.splitlines()[5:]).strip('\n') diff_code = f"\n\n```diff\n{patch.rstrip()}\n```" body_fallback += diff_code # Create a general note on the file in the MR - self.mr.notes.create({ - 'body': body_fallback, - 'position': { - 'base_sha': diff.base_commit_sha, - 'start_sha': diff.start_commit_sha, - 'head_sha': diff.head_commit_sha, - 'position_type': 'text', - 'file_path': f'{target_file.filename}', + self.mr.notes.create( + { + 'body': body_fallback, + 'position': { + 'base_sha': diff.base_commit_sha, + 'start_sha': diff.start_commit_sha, + 'head_sha': diff.head_commit_sha, + 'position_type': 'text', + 'file_path': f'{target_file.filename}', + }, } - }) - get_logger().debug(f"Created fallback comment in MR {self.id_mr} with position {pos_obj}") + ) + get_logger().debug( + f"Created fallback comment in MR {self.id_mr} with position {pos_obj}" + ) # get_logger().debug( # f"Failed to create comment in MR {self.id_mr} with position {pos_obj} (probably not a '+' line)") except Exception as e: - get_logger().exception(f"Failed to create comment in MR {self.id_mr}") + get_logger().exception( + f"Failed to create comment in MR {self.id_mr}" + ) - def get_relevant_diff(self, relevant_file: str, relevant_line_in_file: str) -> Optional[dict]: + def get_relevant_diff( + self, relevant_file: str, relevant_line_in_file: str + ) -> Optional[dict]: changes = self.mr.changes() # Retrieve the changes for the merge request once if not changes: get_logger().error('No changes found for the merge request.') @@ -327,10 +444,14 @@ class GitLabProvider(GitProvider): return None for diff in all_diffs: for change in changes['changes']: - if change['new_path'] == relevant_file and relevant_line_in_file in change['diff']: + if ( + change['new_path'] == relevant_file + and relevant_line_in_file in change['diff'] + ): return diff get_logger().debug( - f'No relevant diff found for {relevant_file} {relevant_line_in_file}. Falling back to last diff.') + f'No relevant diff found for {relevant_file} {relevant_line_in_file}. Falling back to last diff.' + ) return self.last_diff # fallback to last_diff if no relevant diff is found def publish_code_suggestions(self, code_suggestions: list) -> bool: @@ -352,7 +473,7 @@ class GitLabProvider(GitProvider): if file.filename == relevant_file: target_file = file break - range = relevant_lines_end - relevant_lines_start # no need to add 1 + range = relevant_lines_end - relevant_lines_start # no need to add 1 body = body.replace('```suggestion', f'```suggestion:-0+{range}') lines = target_file.head_file.splitlines() relevant_line_in_file = lines[relevant_lines_start - 1] @@ -365,10 +486,21 @@ class GitLabProvider(GitProvider): found = True edit_type = 'addition' - self.send_inline_comment(body, edit_type, found, relevant_file, relevant_line_in_file, source_line_no, - target_file, target_line_no, original_suggestion) + self.send_inline_comment( + body, + edit_type, + found, + relevant_file, + relevant_line_in_file, + source_line_no, + target_file, + target_line_no, + original_suggestion, + ) except Exception as e: - get_logger().exception(f"Could not publish code suggestion:\nsuggestion: {suggestion}\nerror: {e}") + get_logger().exception( + f"Could not publish code suggestion:\nsuggestion: {suggestion}\nerror: {e}" + ) # note that we publish suggestions one-by-one. 
so, if one fails, the rest will still be published return True @@ -382,8 +514,13 @@ class GitLabProvider(GitProvider): edit_type = self.get_edit_type(relevant_line_in_file) for file in self.get_diff_files(): if file.filename == relevant_file: - edit_type, found, source_line_no, target_file, target_line_no = self.find_in_file(file, - relevant_line_in_file) + ( + edit_type, + found, + source_line_no, + target_file, + target_line_no, + ) = self.find_in_file(file, relevant_line_in_file) return edit_type, found, source_line_no, target_file, target_line_no def find_in_file(self, file, relevant_line_in_file): @@ -414,7 +551,10 @@ class GitLabProvider(GitProvider): found = True edit_type = self.get_edit_type(line) break - elif relevant_line_in_file[0] == '+' and relevant_line_in_file[1:].lstrip() in line: + elif ( + relevant_line_in_file[0] == '+' + and relevant_line_in_file[1:].lstrip() in line + ): # The model often adds a '+' to the beginning of the relevant_line_in_file even if originally # it's a context line found = True @@ -470,7 +610,11 @@ class GitLabProvider(GitProvider): def get_repo_settings(self): try: - contents = self.gl.projects.get(self.id_project).files.get(file_path='.pr_agent.toml', ref=self.mr.target_branch).decode() + contents = ( + self.gl.projects.get(self.id_project) + .files.get(file_path='.pr_agent.toml', ref=self.mr.target_branch) + .decode() + ) return contents except Exception: return "" @@ -478,7 +622,9 @@ class GitLabProvider(GitProvider): def get_workspace_name(self): return self.id_project.split('/')[0] - def add_eyes_reaction(self, issue_comment_id: int, disable_eyes: bool = False) -> Optional[int]: + def add_eyes_reaction( + self, issue_comment_id: int, disable_eyes: bool = False + ) -> Optional[int]: return True def remove_reaction(self, issue_comment_id: int, reaction_id: int) -> bool: @@ -489,7 +635,9 @@ class GitLabProvider(GitProvider): path_parts = parsed_url.path.strip('/').split('/') if 'merge_requests' not in path_parts: - raise ValueError("The provided URL does not appear to be a GitLab merge request URL") + raise ValueError( + "The provided URL does not appear to be a GitLab merge request URL" + ) mr_index = path_parts.index('merge_requests') # Ensure there is an ID after 'merge_requests' @@ -541,8 +689,15 @@ class GitLabProvider(GitProvider): """ max_tokens = get_settings().get("CONFIG.MAX_COMMITS_TOKENS", None) try: - commit_messages_list = [commit['message'] for commit in self.mr.commits()._list] - commit_messages_str = "\n".join([f"{i + 1}. {message}" for i, message in enumerate(commit_messages_list)]) + commit_messages_list = [ + commit['message'] for commit in self.mr.commits()._list + ] + commit_messages_str = "\n".join( + [ + f"{i + 1}. 
{message}" + for i, message in enumerate(commit_messages_list) + ] + ) except Exception: commit_messages_str = "" if max_tokens: @@ -556,7 +711,12 @@ class GitLabProvider(GitProvider): except: return "" - def get_line_link(self, relevant_file: str, relevant_line_start: int, relevant_line_end: int = None) -> str: + def get_line_link( + self, + relevant_file: str, + relevant_line_start: int, + relevant_line_end: int = None, + ) -> str: if relevant_line_start == -1: link = f"{self.gl.url}/{self.id_project}/-/blob/{self.mr.source_branch}/{relevant_file}?ref_type=heads" elif relevant_line_end: @@ -565,7 +725,6 @@ class GitLabProvider(GitProvider): link = f"{self.gl.url}/{self.id_project}/-/blob/{self.mr.source_branch}/{relevant_file}?ref_type=heads#L{relevant_line_start}" return link - def generate_link_to_relevant_line_number(self, suggestion) -> str: try: relevant_file = suggestion['relevant_file'].strip('`').strip("'").rstrip() @@ -573,8 +732,9 @@ class GitLabProvider(GitProvider): if not relevant_line_str: return "" - position, absolute_position = find_line_number_of_relevant_line_in_file \ - (self.diff_files, relevant_file, relevant_line_str) + position, absolute_position = find_line_number_of_relevant_line_in_file( + self.diff_files, relevant_file, relevant_line_str + ) if absolute_position != -1: # link to right file only diff --git a/apps/utils/pr_agent/git_providers/local_git_provider.py b/apps/utils/pr_agent/git_providers/local_git_provider.py index 0571bcf..cbf08db 100644 --- a/apps/utils/pr_agent/git_providers/local_git_provider.py +++ b/apps/utils/pr_agent/git_providers/local_git_provider.py @@ -39,10 +39,16 @@ class LocalGitProvider(GitProvider): self._prepare_repo() self.diff_files = None self.pr = PullRequestMimic(self.get_pr_title(), self.get_diff_files()) - self.description_path = get_settings().get('local.description_path') \ - if get_settings().get('local.description_path') is not None else self.repo_path / 'description.md' - self.review_path = get_settings().get('local.review_path') \ - if get_settings().get('local.review_path') is not None else self.repo_path / 'review.md' + self.description_path = ( + get_settings().get('local.description_path') + if get_settings().get('local.description_path') is not None + else self.repo_path / 'description.md' + ) + self.review_path = ( + get_settings().get('local.review_path') + if get_settings().get('local.review_path') is not None + else self.repo_path / 'review.md' + ) # inline code comments are not supported for local git repositories get_settings().pr_reviewer.inline_code_comments = False @@ -52,30 +58,43 @@ class LocalGitProvider(GitProvider): """ get_logger().debug('Preparing repository for PR-mimic generation...') if self.repo.is_dirty(): - raise ValueError('The repository is not in a clean state. Please commit or stash pending changes.') + raise ValueError( + 'The repository is not in a clean state. Please commit or stash pending changes.' 
+ ) if self.target_branch_name not in self.repo.heads: raise KeyError(f'Branch: {self.target_branch_name} does not exist') def is_supported(self, capability: str) -> bool: - if capability in ['get_issue_comments', 'create_inline_comment', 'publish_inline_comments', 'get_labels', - 'gfm_markdown']: + if capability in [ + 'get_issue_comments', + 'create_inline_comment', + 'publish_inline_comments', + 'get_labels', + 'gfm_markdown', + ]: return False return True def get_diff_files(self) -> list[FilePatchInfo]: diffs = self.repo.head.commit.diff( - self.repo.merge_base(self.repo.head, self.repo.branches[self.target_branch_name]), + self.repo.merge_base( + self.repo.head, self.repo.branches[self.target_branch_name] + ), create_patch=True, - R=True + R=True, ) diff_files = [] for diff_item in diffs: if diff_item.a_blob is not None: - original_file_content_str = diff_item.a_blob.data_stream.read().decode('utf-8') + original_file_content_str = diff_item.a_blob.data_stream.read().decode( + 'utf-8' + ) else: original_file_content_str = "" # empty file if diff_item.b_blob is not None: - new_file_content_str = diff_item.b_blob.data_stream.read().decode('utf-8') + new_file_content_str = diff_item.b_blob.data_stream.read().decode( + 'utf-8' + ) else: new_file_content_str = "" # empty file edit_type = EDIT_TYPE.MODIFIED @@ -86,13 +105,16 @@ class LocalGitProvider(GitProvider): elif diff_item.renamed_file: edit_type = EDIT_TYPE.RENAMED diff_files.append( - FilePatchInfo(original_file_content_str, - new_file_content_str, - diff_item.diff.decode('utf-8'), - diff_item.b_path, - edit_type=edit_type, - old_filename=None if diff_item.a_path == diff_item.b_path else diff_item.a_path - ) + FilePatchInfo( + original_file_content_str, + new_file_content_str, + diff_item.diff.decode('utf-8'), + diff_item.b_path, + edit_type=edit_type, + old_filename=None + if diff_item.a_path == diff_item.b_path + else diff_item.a_path, + ) ) self.diff_files = diff_files return diff_files @@ -102,8 +124,10 @@ class LocalGitProvider(GitProvider): Returns a list of files with changes in the diff. 
""" diff_index = self.repo.head.commit.diff( - self.repo.merge_base(self.repo.head, self.repo.branches[self.target_branch_name]), - R=True + self.repo.merge_base( + self.repo.head, self.repo.branches[self.target_branch_name] + ), + R=True, ) # Get the list of changed files diff_files = [item.a_path for item in diff_index] @@ -119,18 +143,37 @@ class LocalGitProvider(GitProvider): # Write the string to the file file.write(pr_comment) - def publish_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str, original_suggestion=None): - raise NotImplementedError('Publishing inline comments is not implemented for the local git provider') + def publish_inline_comment( + self, + body: str, + relevant_file: str, + relevant_line_in_file: str, + original_suggestion=None, + ): + raise NotImplementedError( + 'Publishing inline comments is not implemented for the local git provider' + ) def publish_inline_comments(self, comments: list[dict]): - raise NotImplementedError('Publishing inline comments is not implemented for the local git provider') + raise NotImplementedError( + 'Publishing inline comments is not implemented for the local git provider' + ) - def publish_code_suggestion(self, body: str, relevant_file: str, - relevant_lines_start: int, relevant_lines_end: int): - raise NotImplementedError('Publishing code suggestions is not implemented for the local git provider') + def publish_code_suggestion( + self, + body: str, + relevant_file: str, + relevant_lines_start: int, + relevant_lines_end: int, + ): + raise NotImplementedError( + 'Publishing code suggestions is not implemented for the local git provider' + ) def publish_code_suggestions(self, code_suggestions: list) -> bool: - raise NotImplementedError('Publishing code suggestions is not implemented for the local git provider') + raise NotImplementedError( + 'Publishing code suggestions is not implemented for the local git provider' + ) def publish_labels(self, labels): pass # Not applicable to the local git provider, but required by the interface @@ -158,19 +201,31 @@ class LocalGitProvider(GitProvider): Calculate percentage of languages in repository. Used for hunk prioritisation. 
""" # Get all files in repository - filepaths = [Path(item.path) for item in self.repo.tree().traverse() if item.type == 'blob'] + filepaths = [ + Path(item.path) + for item in self.repo.tree().traverse() + if item.type == 'blob' + ] # Identify language by file extension and count - lang_count = Counter(ext.lstrip('.') for filepath in filepaths for ext in [filepath.suffix.lower()]) + lang_count = Counter( + ext.lstrip('.') + for filepath in filepaths + for ext in [filepath.suffix.lower()] + ) # Convert counts to percentages total_files = len(filepaths) - lang_percentage = {lang: count / total_files * 100 for lang, count in lang_count.items()} + lang_percentage = { + lang: count / total_files * 100 for lang, count in lang_count.items() + } return lang_percentage def get_pr_branch(self): return self.repo.head def get_user_id(self): - return -1 # Not used anywhere for the local provider, but required by the interface + return ( + -1 + ) # Not used anywhere for the local provider, but required by the interface def get_pr_description_full(self): commits_diff = list(self.repo.iter_commits(self.target_branch_name + '..HEAD')) @@ -186,7 +241,11 @@ class LocalGitProvider(GitProvider): return self.head_branch_name def get_issue_comments(self): - raise NotImplementedError('Getting issue comments is not implemented for the local git provider') + raise NotImplementedError( + 'Getting issue comments is not implemented for the local git provider' + ) def get_pr_labels(self, update=False): - raise NotImplementedError('Getting labels is not implemented for the local git provider') + raise NotImplementedError( + 'Getting labels is not implemented for the local git provider' + ) diff --git a/apps/utils/pr_agent/git_providers/utils.py b/apps/utils/pr_agent/git_providers/utils.py index 1693c34..41ec496 100644 --- a/apps/utils/pr_agent/git_providers/utils.py +++ b/apps/utils/pr_agent/git_providers/utils.py @@ -6,7 +6,7 @@ from dynaconf import Dynaconf from starlette_context import context from utils.pr_agent.config_loader import get_settings -from utils.pr_agent.git_providers import (get_git_provider_with_context) +from utils.pr_agent.git_providers import get_git_provider_with_context from utils.pr_agent.log import get_logger @@ -20,7 +20,9 @@ def apply_repo_settings(pr_url): except Exception: repo_settings = None pass - if repo_settings is None: # None is different from "", which is a valid value + if ( + repo_settings is None + ): # None is different from "", which is a valid value repo_settings = git_provider.get_repo_settings() try: context["repo_settings"] = repo_settings @@ -36,15 +38,25 @@ def apply_repo_settings(pr_url): os.write(fd, repo_settings) new_settings = Dynaconf(settings_files=[repo_settings_file]) for section, contents in new_settings.as_dict().items(): - section_dict = copy.deepcopy(get_settings().as_dict().get(section, {})) + section_dict = copy.deepcopy( + get_settings().as_dict().get(section, {}) + ) for key, value in contents.items(): section_dict[key] = value get_settings().unset(section) get_settings().set(section, section_dict, merge=False) - get_logger().info(f"Applying repo settings:\n{new_settings.as_dict()}") + get_logger().info( + f"Applying repo settings:\n{new_settings.as_dict()}" + ) except Exception as e: - get_logger().warning(f"Failed to apply repo {category} settings, error: {str(e)}") - error_local = {'error': str(e), 'settings': repo_settings, 'category': category} + get_logger().warning( + f"Failed to apply repo {category} settings, error: {str(e)}" + ) + error_local = 
{ + 'error': str(e), + 'settings': repo_settings, + 'category': category, + } if error_local: handle_configurations_errors([error_local], git_provider) @@ -55,7 +67,10 @@ def apply_repo_settings(pr_url): try: os.remove(repo_settings_file) except Exception as e: - get_logger().error(f"Failed to remove temporary settings file {repo_settings_file}", e) + get_logger().error( + f"Failed to remove temporary settings file {repo_settings_file}", + e, + ) # enable switching models with a short definition if get_settings().config.model.lower() == 'claude-3-5-sonnet': @@ -79,13 +94,18 @@ def handle_configurations_errors(config_errors, git_provider): body += f"\n\n
<details><summary>配置内容:</summary>\n\n```toml\n{configuration_file_content}\n```\n\n</details>
" else: body += f"\n\n**配置内容:**\n\n```toml\n{configuration_file_content}\n```\n\n" - get_logger().warning(f"Sending a 'configuration error' comment to the PR", artifact={'body': body}) + get_logger().warning( + f"Sending a 'configuration error' comment to the PR", + artifact={'body': body}, + ) # git_provider.publish_comment(body) if hasattr(git_provider, 'publish_persistent_comment'): - git_provider.publish_persistent_comment(body, - initial_header=header, - update_header=False, - final_update_message=False) + git_provider.publish_persistent_comment( + body, + initial_header=header, + update_header=False, + final_update_message=False, + ) else: git_provider.publish_comment(body) except Exception as e: diff --git a/apps/utils/pr_agent/identity_providers/__init__.py b/apps/utils/pr_agent/identity_providers/__init__.py index f816170..c7b5a3a 100644 --- a/apps/utils/pr_agent/identity_providers/__init__.py +++ b/apps/utils/pr_agent/identity_providers/__init__.py @@ -1,10 +1,9 @@ from utils.pr_agent.config_loader import get_settings -from utils.pr_agent.identity_providers.default_identity_provider import \ - DefaultIdentityProvider +from utils.pr_agent.identity_providers.default_identity_provider import ( + DefaultIdentityProvider, +) -_IDENTITY_PROVIDERS = { - 'default': DefaultIdentityProvider -} +_IDENTITY_PROVIDERS = {'default': DefaultIdentityProvider} def get_identity_provider(): diff --git a/apps/utils/pr_agent/identity_providers/default_identity_provider.py b/apps/utils/pr_agent/identity_providers/default_identity_provider.py index d30f17e..dee0fb8 100644 --- a/apps/utils/pr_agent/identity_providers/default_identity_provider.py +++ b/apps/utils/pr_agent/identity_providers/default_identity_provider.py @@ -1,5 +1,7 @@ -from utils.pr_agent.identity_providers.identity_provider import (Eligibility, - IdentityProvider) +from utils.pr_agent.identity_providers.identity_provider import ( + Eligibility, + IdentityProvider, +) class DefaultIdentityProvider(IdentityProvider): diff --git a/apps/utils/pr_agent/log/__init__.py b/apps/utils/pr_agent/log/__init__.py index 658e1e7..c9686ca 100644 --- a/apps/utils/pr_agent/log/__init__.py +++ b/apps/utils/pr_agent/log/__init__.py @@ -30,7 +30,9 @@ def setup_logger(level: str = "INFO", fmt: LoggingFormat = LoggingFormat.CONSOLE if type(level) is not int: level = logging.INFO - if fmt == LoggingFormat.JSON and os.getenv("LOG_SANE", "0").lower() == "0": # better debugging github_app + if ( + fmt == LoggingFormat.JSON and os.getenv("LOG_SANE", "0").lower() == "0" + ): # better debugging github_app logger.remove(None) logger.add( sys.stdout, @@ -40,7 +42,7 @@ def setup_logger(level: str = "INFO", fmt: LoggingFormat = LoggingFormat.CONSOLE colorize=False, serialize=True, ) - elif fmt == LoggingFormat.CONSOLE: # does not print the 'extra' fields + elif fmt == LoggingFormat.CONSOLE: # does not print the 'extra' fields logger.remove(None) logger.add(sys.stdout, level=level, colorize=True, filter=inv_analytics_filter) diff --git a/apps/utils/pr_agent/secret_providers/__init__.py b/apps/utils/pr_agent/secret_providers/__init__.py index cfd3e5d..3839216 100644 --- a/apps/utils/pr_agent/secret_providers/__init__.py +++ b/apps/utils/pr_agent/secret_providers/__init__.py @@ -8,10 +8,14 @@ def get_secret_provider(): provider_id = get_settings().config.secret_provider if provider_id == 'google_cloud_storage': try: - from utils.pr_agent.secret_providers.google_cloud_storage_secret_provider import \ - GoogleCloudStorageSecretProvider + from 
utils.pr_agent.secret_providers.google_cloud_storage_secret_provider import ( + GoogleCloudStorageSecretProvider, + ) + return GoogleCloudStorageSecretProvider() except Exception as e: - raise ValueError(f"Failed to initialize google_cloud_storage secret provider {provider_id}") from e + raise ValueError( + f"Failed to initialize google_cloud_storage secret provider {provider_id}" + ) from e else: raise ValueError("Unknown SECRET_PROVIDER") diff --git a/apps/utils/pr_agent/secret_providers/google_cloud_storage_secret_provider.py b/apps/utils/pr_agent/secret_providers/google_cloud_storage_secret_provider.py index 9784d47..6240393 100644 --- a/apps/utils/pr_agent/secret_providers/google_cloud_storage_secret_provider.py +++ b/apps/utils/pr_agent/secret_providers/google_cloud_storage_secret_provider.py @@ -9,12 +9,15 @@ from utils.pr_agent.secret_providers.secret_provider import SecretProvider class GoogleCloudStorageSecretProvider(SecretProvider): def __init__(self): try: - self.client = storage.Client.from_service_account_info(ujson.loads(get_settings().google_cloud_storage. - service_account)) + self.client = storage.Client.from_service_account_info( + ujson.loads(get_settings().google_cloud_storage.service_account) + ) self.bucket_name = get_settings().google_cloud_storage.bucket_name self.bucket = self.client.bucket(self.bucket_name) except Exception as e: - get_logger().error(f"Failed to initialize Google Cloud Storage Secret Provider: {e}") + get_logger().error( + f"Failed to initialize Google Cloud Storage Secret Provider: {e}" + ) raise e def get_secret(self, secret_name: str) -> str: @@ -22,7 +25,9 @@ class GoogleCloudStorageSecretProvider(SecretProvider): blob = self.bucket.blob(secret_name) return blob.download_as_string() except Exception as e: - get_logger().warning(f"Failed to get secret {secret_name} from Google Cloud Storage: {e}") + get_logger().warning( + f"Failed to get secret {secret_name} from Google Cloud Storage: {e}" + ) return "" def store_secret(self, secret_name: str, secret_value: str): @@ -30,5 +35,7 @@ class GoogleCloudStorageSecretProvider(SecretProvider): blob = self.bucket.blob(secret_name) blob.upload_from_string(secret_value) except Exception as e: - get_logger().error(f"Failed to store secret {secret_name} in Google Cloud Storage: {e}") + get_logger().error( + f"Failed to store secret {secret_name} in Google Cloud Storage: {e}" + ) raise e diff --git a/apps/utils/pr_agent/secret_providers/secret_provider.py b/apps/utils/pr_agent/secret_providers/secret_provider.py index df1e778..80a5b10 100644 --- a/apps/utils/pr_agent/secret_providers/secret_provider.py +++ b/apps/utils/pr_agent/secret_providers/secret_provider.py @@ -2,7 +2,6 @@ from abc import ABC, abstractmethod class SecretProvider(ABC): - @abstractmethod def get_secret(self, secret_name: str) -> str: pass diff --git a/apps/utils/pr_agent/servers/azuredevops_server_webhook.py b/apps/utils/pr_agent/servers/azuredevops_server_webhook.py index e77f977..a98abb3 100644 --- a/apps/utils/pr_agent/servers/azuredevops_server_webhook.py +++ b/apps/utils/pr_agent/servers/azuredevops_server_webhook.py @@ -33,6 +33,7 @@ azure_devops_server = get_settings().get("azure_devops_server") WEBHOOK_USERNAME = azure_devops_server.get("webhook_username") WEBHOOK_PASSWORD = azure_devops_server.get("webhook_password") + def handle_request( background_tasks: BackgroundTasks, url: str, body: str, log_context: dict ): @@ -52,20 +53,27 @@ def handle_request( # currently only basic auth is supported with azure webhooks # for this 
reason, https must be enabled to ensure the credentials are not sent in clear text def authorize(credentials: HTTPBasicCredentials = Depends(security)): - is_user_ok = secrets.compare_digest(credentials.username, WEBHOOK_USERNAME) - is_pass_ok = secrets.compare_digest(credentials.password, WEBHOOK_PASSWORD) - if not (is_user_ok and is_pass_ok): - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail='Incorrect username or password.', - headers={'WWW-Authenticate': 'Basic'}, - ) + is_user_ok = secrets.compare_digest(credentials.username, WEBHOOK_USERNAME) + is_pass_ok = secrets.compare_digest(credentials.password, WEBHOOK_PASSWORD) + if not (is_user_ok and is_pass_ok): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail='Incorrect username or password.', + headers={'WWW-Authenticate': 'Basic'}, + ) -async def _perform_commands_azure(commands_conf: str, agent: PRAgent, api_url: str, log_context: dict): +async def _perform_commands_azure( + commands_conf: str, agent: PRAgent, api_url: str, log_context: dict +): apply_repo_settings(api_url) - if commands_conf == "pr_commands" and get_settings().config.disable_auto_feedback: # auto commands for PR, and auto feedback is disabled - get_logger().info(f"Auto feedback is disabled, skipping auto commands for PR {api_url=}", **log_context) + if ( + commands_conf == "pr_commands" and get_settings().config.disable_auto_feedback + ): # auto commands for PR, and auto feedback is disabled + get_logger().info( + f"Auto feedback is disabled, skipping auto commands for PR {api_url=}", + **log_context, + ) return commands = get_settings().get(f"azure_devops_server.{commands_conf}") get_settings().set("config.is_auto_command", True) @@ -92,22 +100,38 @@ async def handle_webhook(background_tasks: BackgroundTasks, request: Request): actions = [] if data["eventType"] == "git.pullrequest.created": # API V1 (latest) - pr_url = unquote(data["resource"]["_links"]["web"]["href"].replace("_apis/git/repositories", "_git")) + pr_url = unquote( + data["resource"]["_links"]["web"]["href"].replace( + "_apis/git/repositories", "_git" + ) + ) log_context["event"] = data["eventType"] log_context["api_url"] = pr_url await _perform_commands_azure("pr_commands", PRAgent(), pr_url, log_context) return - elif data["eventType"] == "ms.vss-code.git-pullrequest-comment-event" and "content" in data["resource"]["comment"]: + elif ( + data["eventType"] == "ms.vss-code.git-pullrequest-comment-event" + and "content" in data["resource"]["comment"] + ): if available_commands_rgx.match(data["resource"]["comment"]["content"]): - if(data["resourceVersion"] == "2.0"): + if data["resourceVersion"] == "2.0": repo = data["resource"]["pullRequest"]["repository"]["webUrl"] - pr_url = unquote(f'{repo}/pullrequest/{data["resource"]["pullRequest"]["pullRequestId"]}') + pr_url = unquote( + f'{repo}/pullrequest/{data["resource"]["pullRequest"]["pullRequestId"]}' + ) actions = [data["resource"]["comment"]["content"]] else: # API V1 not supported as it does not contain the PR URL - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content=json.dumps({"message": "version 1.0 webhook for Azure Devops PR comment is not supported. please upgrade to version 2.0"})), + return ( + JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content=json.dumps( + { + "message": "version 1.0 webhook for Azure Devops PR comment is not supported. 
please upgrade to version 2.0" + } + ), + ), + ) else: return JSONResponse( status_code=status.HTTP_400_BAD_REQUEST, @@ -132,17 +156,21 @@ async def handle_webhook(background_tasks: BackgroundTasks, request: Request): content=json.dumps({"message": "Internal server error"}), ) return JSONResponse( - status_code=status.HTTP_202_ACCEPTED, content=jsonable_encoder({"message": "webhook triggered successfully"}) + status_code=status.HTTP_202_ACCEPTED, + content=jsonable_encoder({"message": "webhook triggered successfully"}), ) + @router.get("/") async def root(): return {"status": "ok"} + def start(): app = FastAPI(middleware=[Middleware(RawContextMiddleware)]) app.include_router(router) uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "3000"))) + if __name__ == "__main__": start() diff --git a/apps/utils/pr_agent/servers/bitbucket_app.py b/apps/utils/pr_agent/servers/bitbucket_app.py index 7fa9ca3..42e3982 100644 --- a/apps/utils/pr_agent/servers/bitbucket_app.py +++ b/apps/utils/pr_agent/servers/bitbucket_app.py @@ -27,7 +27,9 @@ from utils.pr_agent.secret_providers import get_secret_provider setup_logger(fmt=LoggingFormat.JSON, level="DEBUG") router = APIRouter() -secret_provider = get_secret_provider() if get_settings().get("CONFIG.SECRET_PROVIDER") else None +secret_provider = ( + get_secret_provider() if get_settings().get("CONFIG.SECRET_PROVIDER") else None +) async def get_bearer_token(shared_secret: str, client_key: str): @@ -44,12 +46,12 @@ async def get_bearer_token(shared_secret: str, client_key: str): "exp": now + 240, "qsh": qsh, "sub": client_key, - } + } token = jwt.encode(payload, shared_secret, algorithm="HS256") payload = 'grant_type=urn%3Abitbucket%3Aoauth2%3Ajwt' headers = { 'Authorization': f'JWT {token}', - 'Content-Type': 'application/x-www-form-urlencoded' + 'Content-Type': 'application/x-www-form-urlencoded', } response = requests.request("POST", url, headers=headers, data=payload) bearer_token = response.json()["access_token"] @@ -58,6 +60,7 @@ async def get_bearer_token(shared_secret: str, client_key: str): get_logger().error(f"Failed to get bearer token: {e}") raise e + @router.get("/") async def handle_manifest(request: Request, response: Response): cur_dir = os.path.dirname(os.path.abspath(__file__)) @@ -66,7 +69,9 @@ async def handle_manifest(request: Request, response: Response): manifest = manifest.replace("app_key", get_settings().bitbucket.app_key) manifest = manifest.replace("base_url", get_settings().bitbucket.base_url) except: - get_logger().error("Failed to replace api_key in Bitbucket manifest, trying to continue") + get_logger().error( + "Failed to replace api_key in Bitbucket manifest, trying to continue" + ) manifest_obj = json.loads(manifest) return JSONResponse(manifest_obj) @@ -83,10 +88,16 @@ def _get_username(data): return "" -async def _perform_commands_bitbucket(commands_conf: str, agent: PRAgent, api_url: str, log_context: dict, data: dict): +async def _perform_commands_bitbucket( + commands_conf: str, agent: PRAgent, api_url: str, log_context: dict, data: dict +): apply_repo_settings(api_url) - if commands_conf == "pr_commands" and get_settings().config.disable_auto_feedback: # auto commands for PR, and auto feedback is disabled - get_logger().info(f"Auto feedback is disabled, skipping auto commands for PR {api_url=}") + if ( + commands_conf == "pr_commands" and get_settings().config.disable_auto_feedback + ): # auto commands for PR, and auto feedback is disabled + get_logger().info( + f"Auto feedback is disabled, skipping 
auto commands for PR {api_url=}" + ) return if data.get("event", "") == "pullrequest:created": if not should_process_pr_logic(data): @@ -132,7 +143,9 @@ def should_process_pr_logic(data) -> bool: ignore_pr_users = get_settings().get("CONFIG.IGNORE_PR_AUTHORS", []) if ignore_pr_users and sender: if sender in ignore_pr_users: - get_logger().info(f"Ignoring PR from user '{sender}' due to 'config.ignore_pr_authors' setting") + get_logger().info( + f"Ignoring PR from user '{sender}' due to 'config.ignore_pr_authors' setting" + ) return False # logic to ignore PRs with specific titles @@ -140,20 +153,34 @@ def should_process_pr_logic(data) -> bool: ignore_pr_title_re = get_settings().get("CONFIG.IGNORE_PR_TITLE", []) if not isinstance(ignore_pr_title_re, list): ignore_pr_title_re = [ignore_pr_title_re] - if ignore_pr_title_re and any(re.search(regex, title) for regex in ignore_pr_title_re): - get_logger().info(f"Ignoring PR with title '{title}' due to config.ignore_pr_title setting") + if ignore_pr_title_re and any( + re.search(regex, title) for regex in ignore_pr_title_re + ): + get_logger().info( + f"Ignoring PR with title '{title}' due to config.ignore_pr_title setting" + ) return False - ignore_pr_source_branches = get_settings().get("CONFIG.IGNORE_PR_SOURCE_BRANCHES", []) - ignore_pr_target_branches = get_settings().get("CONFIG.IGNORE_PR_TARGET_BRANCHES", []) - if (ignore_pr_source_branches or ignore_pr_target_branches): - if any(re.search(regex, source_branch) for regex in ignore_pr_source_branches): + ignore_pr_source_branches = get_settings().get( + "CONFIG.IGNORE_PR_SOURCE_BRANCHES", [] + ) + ignore_pr_target_branches = get_settings().get( + "CONFIG.IGNORE_PR_TARGET_BRANCHES", [] + ) + if ignore_pr_source_branches or ignore_pr_target_branches: + if any( + re.search(regex, source_branch) for regex in ignore_pr_source_branches + ): get_logger().info( - f"Ignoring PR with source branch '{source_branch}' due to config.ignore_pr_source_branches settings") + f"Ignoring PR with source branch '{source_branch}' due to config.ignore_pr_source_branches settings" + ) return False - if any(re.search(regex, target_branch) for regex in ignore_pr_target_branches): + if any( + re.search(regex, target_branch) for regex in ignore_pr_target_branches + ): get_logger().info( - f"Ignoring PR with target branch '{target_branch}' due to config.ignore_pr_target_branches settings") + f"Ignoring PR with target branch '{target_branch}' due to config.ignore_pr_target_branches settings" + ) return False except Exception as e: get_logger().error(f"Failed 'should_process_pr_logic': {e}") @@ -195,7 +222,9 @@ async def handle_github_webhooks(background_tasks: BackgroundTasks, request: Req client_key = claims["iss"] secrets = json.loads(secret_provider.get_secret(client_key)) shared_secret = secrets["shared_secret"] - jwt.decode(input_jwt, shared_secret, audience=client_key, algorithms=["HS256"]) + jwt.decode( + input_jwt, shared_secret, audience=client_key, algorithms=["HS256"] + ) bearer_token = await get_bearer_token(shared_secret, client_key) context['bitbucket_bearer_token'] = bearer_token context["settings"] = copy.deepcopy(global_settings) @@ -208,28 +237,41 @@ async def handle_github_webhooks(background_tasks: BackgroundTasks, request: Req if pr_url: with get_logger().contextualize(**log_context): apply_repo_settings(pr_url) - if get_identity_provider().verify_eligibility("bitbucket", - sender_id, pr_url) is not Eligibility.NOT_ELIGIBLE: + if ( + get_identity_provider().verify_eligibility( + "bitbucket", 
sender_id, pr_url + ) + is not Eligibility.NOT_ELIGIBLE + ): if get_settings().get("bitbucket_app.pr_commands"): - await _perform_commands_bitbucket("pr_commands", PRAgent(), pr_url, log_context, data) + await _perform_commands_bitbucket( + "pr_commands", PRAgent(), pr_url, log_context, data + ) elif event == "pullrequest:comment_created": pr_url = data["data"]["pullrequest"]["links"]["html"]["href"] log_context["api_url"] = pr_url log_context["event"] = "comment" comment_body = data["data"]["comment"]["content"]["raw"] with get_logger().contextualize(**log_context): - if get_identity_provider().verify_eligibility("bitbucket", - sender_id, pr_url) is not Eligibility.NOT_ELIGIBLE: + if ( + get_identity_provider().verify_eligibility( + "bitbucket", sender_id, pr_url + ) + is not Eligibility.NOT_ELIGIBLE + ): await agent.handle_request(pr_url, comment_body) except Exception as e: get_logger().error(f"Failed to handle webhook: {e}") + background_tasks.add_task(inner) return "OK" + @router.get("/webhook") async def handle_github_webhooks(request: Request, response: Response): return "Webhook server online!" + @router.post("/installed") async def handle_installed_webhooks(request: Request, response: Response): try: @@ -240,15 +282,13 @@ async def handle_installed_webhooks(request: Request, response: Response): shared_secret = data["sharedSecret"] client_key = data["clientKey"] username = data["principal"]["username"] - secrets = { - "shared_secret": shared_secret, - "client_key": client_key - } + secrets = {"shared_secret": shared_secret, "client_key": client_key} secret_provider.store_secret(username, json.dumps(secrets)) except Exception as e: get_logger().error(f"Failed to register user: {e}") return JSONResponse({"error": "Unable to register user"}, status_code=500) + @router.post("/uninstalled") async def handle_uninstalled_webhooks(request: Request, response: Response): get_logger().info("handle_uninstalled_webhooks") diff --git a/apps/utils/pr_agent/servers/bitbucket_server_webhook.py b/apps/utils/pr_agent/servers/bitbucket_server_webhook.py index d291b48..036d9ef 100644 --- a/apps/utils/pr_agent/servers/bitbucket_server_webhook.py +++ b/apps/utils/pr_agent/servers/bitbucket_server_webhook.py @@ -40,10 +40,12 @@ def handle_request( background_tasks.add_task(inner) + @router.post("/") async def redirect_to_webhook(): return RedirectResponse(url="/webhook") + @router.post("/webhook") async def handle_webhook(background_tasks: BackgroundTasks, request: Request): log_context = {"server_type": "bitbucket_server"} @@ -55,7 +57,8 @@ async def handle_webhook(background_tasks: BackgroundTasks, request: Request): body_bytes = await request.body() if body_bytes.decode('utf-8') == '{"test": true}': return JSONResponse( - status_code=status.HTTP_200_OK, content=jsonable_encoder({"message": "connection test successful"}) + status_code=status.HTTP_200_OK, + content=jsonable_encoder({"message": "connection test successful"}), ) signature_header = request.headers.get("x-hub-signature", None) verify_signature(body_bytes, webhook_secret, signature_header) @@ -73,11 +76,18 @@ async def handle_webhook(background_tasks: BackgroundTasks, request: Request): if data["eventKey"] == "pr:opened": apply_repo_settings(pr_url) - if get_settings().config.disable_auto_feedback: # auto commands for PR, and auto feedback is disabled - get_logger().info(f"Auto feedback is disabled, skipping auto commands for PR {pr_url}", **log_context) + if ( + get_settings().config.disable_auto_feedback + ): # auto commands for PR, and 
auto feedback is disabled + get_logger().info( + f"Auto feedback is disabled, skipping auto commands for PR {pr_url}", + **log_context, + ) return get_settings().set("config.is_auto_command", True) - commands_to_run.extend(_get_commands_list_from_settings('BITBUCKET_SERVER.PR_COMMANDS')) + commands_to_run.extend( + _get_commands_list_from_settings('BITBUCKET_SERVER.PR_COMMANDS') + ) elif data["eventKey"] == "pr:comment:added": commands_to_run.append(data["comment"]["text"]) else: @@ -116,6 +126,7 @@ async def _run_commands_sequentially(commands: List[str], url: str, log_context: except Exception as e: get_logger().error(f"Failed to handle command: {command} , error: {e}") + def _process_command(command: str, url) -> str: # don't think we need this apply_repo_settings(url) @@ -142,11 +153,13 @@ def _to_list(command_string: str) -> list: raise ValueError(f"Invalid command string: {e}") -def _get_commands_list_from_settings(setting_key:str ) -> list: +def _get_commands_list_from_settings(setting_key: str) -> list: try: return get_settings().get(setting_key, []) except ValueError as e: - get_logger().error(f"Failed to get commands list from settings {setting_key}: {e}") + get_logger().error( + f"Failed to get commands list from settings {setting_key}: {e}" + ) @router.get("/") diff --git a/apps/utils/pr_agent/servers/gerrit_server.py b/apps/utils/pr_agent/servers/gerrit_server.py index 4831a68..36c5df5 100644 --- a/apps/utils/pr_agent/servers/gerrit_server.py +++ b/apps/utils/pr_agent/servers/gerrit_server.py @@ -40,12 +40,10 @@ async def handle_gerrit_request(action: Action, item: Item): if action == Action.ask: if not item.msg: return HTTPException( - status_code=400, - detail="msg is required for ask command" + status_code=400, detail="msg is required for ask command" ) await PRAgent().handle_request( - f"{item.project}:{item.refspec}", - f"/{item.msg.strip()}" + f"{item.project}:{item.refspec}", f"/{item.msg.strip()}" ) diff --git a/apps/utils/pr_agent/servers/github_action_runner.py b/apps/utils/pr_agent/servers/github_action_runner.py index db50bc6..98b034e 100644 --- a/apps/utils/pr_agent/servers/github_action_runner.py +++ b/apps/utils/pr_agent/servers/github_action_runner.py @@ -26,7 +26,12 @@ def get_setting_or_env(key: str, default: Union[str, bool] = None) -> Union[str, try: value = get_settings().get(key, default) except AttributeError: # TBD still need to debug why this happens on GitHub Actions - value = os.getenv(key, None) or os.getenv(key.upper(), None) or os.getenv(key.lower(), None) or default + value = ( + os.getenv(key, None) + or os.getenv(key.upper(), None) + or os.getenv(key.lower(), None) + or default + ) return value @@ -76,16 +81,24 @@ async def run_action(): pr_url = event_payload.get("pull_request", {}).get("html_url") if pr_url: apply_repo_settings(pr_url) - get_logger().info(f"enable_custom_labels: {get_settings().config.enable_custom_labels}") + get_logger().info( + f"enable_custom_labels: {get_settings().config.enable_custom_labels}" + ) except Exception as e: get_logger().info(f"github action: failed to apply repo settings: {e}") # Handle pull request opened event - if GITHUB_EVENT_NAME == "pull_request" or GITHUB_EVENT_NAME == "pull_request_target": + if ( + GITHUB_EVENT_NAME == "pull_request" + or GITHUB_EVENT_NAME == "pull_request_target" + ): action = event_payload.get("action") # Retrieve the list of actions from the configuration - pr_actions = get_settings().get("GITHUB_ACTION_CONFIG.PR_ACTIONS", ["opened", "reopened", "ready_for_review", 
"review_requested"]) + pr_actions = get_settings().get( + "GITHUB_ACTION_CONFIG.PR_ACTIONS", + ["opened", "reopened", "ready_for_review", "review_requested"], + ) if action in pr_actions: pr_url = event_payload.get("pull_request", {}).get("url") @@ -93,18 +106,30 @@ async def run_action(): # legacy - supporting both GITHUB_ACTION and GITHUB_ACTION_CONFIG auto_review = get_setting_or_env("GITHUB_ACTION.AUTO_REVIEW", None) if auto_review is None: - auto_review = get_setting_or_env("GITHUB_ACTION_CONFIG.AUTO_REVIEW", None) + auto_review = get_setting_or_env( + "GITHUB_ACTION_CONFIG.AUTO_REVIEW", None + ) auto_describe = get_setting_or_env("GITHUB_ACTION.AUTO_DESCRIBE", None) if auto_describe is None: - auto_describe = get_setting_or_env("GITHUB_ACTION_CONFIG.AUTO_DESCRIBE", None) + auto_describe = get_setting_or_env( + "GITHUB_ACTION_CONFIG.AUTO_DESCRIBE", None + ) auto_improve = get_setting_or_env("GITHUB_ACTION.AUTO_IMPROVE", None) if auto_improve is None: - auto_improve = get_setting_or_env("GITHUB_ACTION_CONFIG.AUTO_IMPROVE", None) + auto_improve = get_setting_or_env( + "GITHUB_ACTION_CONFIG.AUTO_IMPROVE", None + ) # Set the configuration for auto actions - get_settings().config.is_auto_command = True # Set the flag to indicate that the command is auto - get_settings().pr_description.final_update_message = False # No final update message when auto_describe is enabled - get_logger().info(f"Running auto actions: auto_describe={auto_describe}, auto_review={auto_review}, auto_improve={auto_improve}") + get_settings().config.is_auto_command = ( + True # Set the flag to indicate that the command is auto + ) + get_settings().pr_description.final_update_message = ( + False # No final update message when auto_describe is enabled + ) + get_logger().info( + f"Running auto actions: auto_describe={auto_describe}, auto_review={auto_review}, auto_improve={auto_improve}" + ) # invoke by default all three tools if auto_describe is None or is_true(auto_describe): @@ -117,7 +142,10 @@ async def run_action(): get_logger().info(f"Skipping action: {action}") # Handle issue comment event - elif GITHUB_EVENT_NAME == "issue_comment" or GITHUB_EVENT_NAME == "pull_request_review_comment": + elif ( + GITHUB_EVENT_NAME == "issue_comment" + or GITHUB_EVENT_NAME == "pull_request_review_comment" + ): action = event_payload.get("action") if action in ["created", "edited"]: comment_body = event_payload.get("comment", {}).get("body") @@ -133,9 +161,15 @@ async def run_action(): disable_eyes = False # check if issue is pull request if event_payload.get("issue", {}).get("pull_request"): - url = event_payload.get("issue", {}).get("pull_request", {}).get("url") + url = ( + event_payload.get("issue", {}) + .get("pull_request", {}) + .get("url") + ) is_pr = True - elif event_payload.get("comment", {}).get("pull_request_url"): # for 'pull_request_review_comment + elif event_payload.get("comment", {}).get( + "pull_request_url" + ): # for 'pull_request_review_comment url = event_payload.get("comment", {}).get("pull_request_url") is_pr = True disable_eyes = True @@ -148,9 +182,11 @@ async def run_action(): provider = get_git_provider()(pr_url=url) if is_pr: await PRAgent().handle_request( - url, body, notify=lambda: provider.add_eyes_reaction( + url, + body, + notify=lambda: provider.add_eyes_reaction( comment_id, disable_eyes=disable_eyes - ) + ), ) else: await PRAgent().handle_request(url, body) diff --git a/apps/utils/pr_agent/servers/github_app.py b/apps/utils/pr_agent/servers/github_app.py index 58bb871..5162b7e 100644 --- 
a/apps/utils/pr_agent/servers/github_app.py +++ b/apps/utils/pr_agent/servers/github_app.py @@ -15,8 +15,7 @@ from starlette_context.middleware import RawContextMiddleware from utils.pr_agent.agent.pr_agent import PRAgent from utils.pr_agent.algo.utils import update_settings_from_args from utils.pr_agent.config_loader import get_settings, global_settings -from utils.pr_agent.git_providers import (get_git_provider, - get_git_provider_with_context) +from utils.pr_agent.git_providers import get_git_provider, get_git_provider_with_context from utils.pr_agent.git_providers.utils import apply_repo_settings from utils.pr_agent.identity_providers import get_identity_provider from utils.pr_agent.identity_providers.identity_provider import Eligibility @@ -35,7 +34,9 @@ router = APIRouter() @router.post("/api/v1/github_webhooks") -async def handle_github_webhooks(background_tasks: BackgroundTasks, request: Request, response: Response): +async def handle_github_webhooks( + background_tasks: BackgroundTasks, request: Request, response: Response +): """ Receives and processes incoming GitHub webhook requests. Verifies the request signature, parses the request body, and passes it to the handle_request function for further @@ -49,7 +50,9 @@ async def handle_github_webhooks(background_tasks: BackgroundTasks, request: Req context["installation_id"] = installation_id context["settings"] = copy.deepcopy(global_settings) context["git_provider"] = {} - background_tasks.add_task(handle_request, body, event=request.headers.get("X-GitHub-Event", None)) + background_tasks.add_task( + handle_request, body, event=request.headers.get("X-GitHub-Event", None) + ) return {} @@ -73,35 +76,61 @@ async def get_body(request): return body -_duplicate_push_triggers = DefaultDictWithTimeout(ttl=get_settings().github_app.push_trigger_pending_tasks_ttl) -_pending_task_duplicate_push_conditions = DefaultDictWithTimeout(asyncio.locks.Condition, ttl=get_settings().github_app.push_trigger_pending_tasks_ttl) +_duplicate_push_triggers = DefaultDictWithTimeout( + ttl=get_settings().github_app.push_trigger_pending_tasks_ttl +) +_pending_task_duplicate_push_conditions = DefaultDictWithTimeout( + asyncio.locks.Condition, + ttl=get_settings().github_app.push_trigger_pending_tasks_ttl, +) -async def handle_comments_on_pr(body: Dict[str, Any], - event: str, - sender: str, - sender_id: str, - action: str, - log_context: Dict[str, Any], - agent: PRAgent): + +async def handle_comments_on_pr( + body: Dict[str, Any], + event: str, + sender: str, + sender_id: str, + action: str, + log_context: Dict[str, Any], + agent: PRAgent, +): if "comment" not in body: return {} comment_body = body.get("comment", {}).get("body") - if comment_body and isinstance(comment_body, str) and not comment_body.lstrip().startswith("/"): + if ( + comment_body + and isinstance(comment_body, str) + and not comment_body.lstrip().startswith("/") + ): if '/ask' in comment_body and comment_body.strip().startswith('> ![image]'): comment_body_split = comment_body.split('/ask') - comment_body = '/ask' + comment_body_split[1] +' \n' +comment_body_split[0].strip().lstrip('>') - get_logger().info(f"Reformatting comment_body so command is at the beginning: {comment_body}") + comment_body = ( + '/ask' + + comment_body_split[1] + + ' \n' + + comment_body_split[0].strip().lstrip('>') + ) + get_logger().info( + f"Reformatting comment_body so command is at the beginning: {comment_body}" + ) else: get_logger().info("Ignoring comment not starting with /") return {} disable_eyes = False - if 
"issue" in body and "pull_request" in body["issue"] and "url" in body["issue"]["pull_request"]: + if ( + "issue" in body + and "pull_request" in body["issue"] + and "url" in body["issue"]["pull_request"] + ): api_url = body["issue"]["pull_request"]["url"] elif "comment" in body and "pull_request_url" in body["comment"]: api_url = body["comment"]["pull_request_url"] try: - if ('/ask' in comment_body and - 'subject_type' in body["comment"] and body["comment"]["subject_type"] == "line"): + if ( + '/ask' in comment_body + and 'subject_type' in body["comment"] + and body["comment"]["subject_type"] == "line" + ): # comment on a code line in the "files changed" tab comment_body = handle_line_comments(body, comment_body) disable_eyes = True @@ -113,46 +142,75 @@ async def handle_comments_on_pr(body: Dict[str, Any], comment_id = body.get("comment", {}).get("id") provider = get_git_provider_with_context(pr_url=api_url) with get_logger().contextualize(**log_context): - if get_identity_provider().verify_eligibility("github", sender_id, api_url) is not Eligibility.NOT_ELIGIBLE: - get_logger().info(f"Processing comment on PR {api_url=}, comment_body={comment_body}") - await agent.handle_request(api_url, comment_body, - notify=lambda: provider.add_eyes_reaction(comment_id, disable_eyes=disable_eyes)) + if ( + get_identity_provider().verify_eligibility("github", sender_id, api_url) + is not Eligibility.NOT_ELIGIBLE + ): + get_logger().info( + f"Processing comment on PR {api_url=}, comment_body={comment_body}" + ) + await agent.handle_request( + api_url, + comment_body, + notify=lambda: provider.add_eyes_reaction( + comment_id, disable_eyes=disable_eyes + ), + ) else: - get_logger().info(f"User {sender=} is not eligible to process comment on PR {api_url=}") + get_logger().info( + f"User {sender=} is not eligible to process comment on PR {api_url=}" + ) -async def handle_new_pr_opened(body: Dict[str, Any], - event: str, - sender: str, - sender_id: str, - action: str, - log_context: Dict[str, Any], - agent: PRAgent): + +async def handle_new_pr_opened( + body: Dict[str, Any], + event: str, + sender: str, + sender_id: str, + action: str, + log_context: Dict[str, Any], + agent: PRAgent, +): title = body.get("pull_request", {}).get("title", "") pull_request, api_url = _check_pull_request_event(action, body, log_context) if not (pull_request and api_url): get_logger().info(f"Invalid PR event: {action=} {api_url=}") return {} - if action in get_settings().github_app.handle_pr_actions: # ['opened', 'reopened', 'ready_for_review'] + if ( + action in get_settings().github_app.handle_pr_actions + ): # ['opened', 'reopened', 'ready_for_review'] # logic to ignore PRs with specific titles (e.g. 
"[Auto] ...") apply_repo_settings(api_url) - if get_identity_provider().verify_eligibility("github", sender_id, api_url) is not Eligibility.NOT_ELIGIBLE: - await _perform_auto_commands_github("pr_commands", agent, body, api_url, log_context) + if ( + get_identity_provider().verify_eligibility("github", sender_id, api_url) + is not Eligibility.NOT_ELIGIBLE + ): + await _perform_auto_commands_github( + "pr_commands", agent, body, api_url, log_context + ) else: - get_logger().info(f"User {sender=} is not eligible to process PR {api_url=}") + get_logger().info( + f"User {sender=} is not eligible to process PR {api_url=}" + ) -async def handle_push_trigger_for_new_commits(body: Dict[str, Any], - event: str, - sender: str, - sender_id: str, - action: str, - log_context: Dict[str, Any], - agent: PRAgent): + +async def handle_push_trigger_for_new_commits( + body: Dict[str, Any], + event: str, + sender: str, + sender_id: str, + action: str, + log_context: Dict[str, Any], + agent: PRAgent, +): pull_request, api_url = _check_pull_request_event(action, body, log_context) if not (pull_request and api_url): return {} - apply_repo_settings(api_url) # we need to apply the repo settings to get the correct settings for the PR. This is quite expensive - a call to the git provider is made for each PR event. + apply_repo_settings( + api_url + ) # we need to apply the repo settings to get the correct settings for the PR. This is quite expensive - a call to the git provider is made for each PR event. if not get_settings().github_app.handle_push_trigger: return {} @@ -162,7 +220,10 @@ async def handle_push_trigger_for_new_commits(body: Dict[str, Any], merge_commit_sha = pull_request.get("merge_commit_sha") if before_sha == after_sha: return {} - if get_settings().github_app.push_trigger_ignore_merge_commits and after_sha == merge_commit_sha: + if ( + get_settings().github_app.push_trigger_ignore_merge_commits + and after_sha == merge_commit_sha + ): return {} # Prevent triggering multiple times for subsequent push triggers when one is enough: @@ -172,7 +233,9 @@ async def handle_push_trigger_for_new_commits(body: Dict[str, Any], # more commits may have been pushed that led to the subsequent events, # so we keep just one waiting as a delegate to trigger the processing for the new commits when done waiting. 
current_active_tasks = _duplicate_push_triggers.setdefault(api_url, 0) - max_active_tasks = 2 if get_settings().github_app.push_trigger_pending_tasks_backlog else 1 + max_active_tasks = ( + 2 if get_settings().github_app.push_trigger_pending_tasks_backlog else 1 + ) if current_active_tasks < max_active_tasks: # first task can enter, and second tasks too if backlog is enabled get_logger().info( @@ -191,12 +254,21 @@ async def handle_push_trigger_for_new_commits(body: Dict[str, Any], f"Waiting to process push trigger for {api_url=} because the first task is still in progress" ) await _pending_task_duplicate_push_conditions[api_url].wait() - get_logger().info(f"Finished waiting to process push trigger for {api_url=} - continue with flow") + get_logger().info( + f"Finished waiting to process push trigger for {api_url=} - continue with flow" + ) try: - if get_identity_provider().verify_eligibility("github", sender_id, api_url) is not Eligibility.NOT_ELIGIBLE: - get_logger().info(f"Performing incremental review for {api_url=} because of {event=} and {action=}") - await _perform_auto_commands_github("push_commands", agent, body, api_url, log_context) + if ( + get_identity_provider().verify_eligibility("github", sender_id, api_url) + is not Eligibility.NOT_ELIGIBLE + ): + get_logger().info( + f"Performing incremental review for {api_url=} because of {event=} and {action=}" + ) + await _perform_auto_commands_github( + "push_commands", agent, body, api_url, log_context + ) finally: # release the waiting task block @@ -213,7 +285,12 @@ def handle_closed_pr(body, event, action, log_context): api_url = pull_request.get("url", "") pr_statistics = get_git_provider()(pr_url=api_url).calc_pr_statistics(pull_request) log_context["api_url"] = api_url - get_logger().info("PR-Agent statistics for closed PR", analytics=True, pr_statistics=pr_statistics, **log_context) + get_logger().info( + "PR-Agent statistics for closed PR", + analytics=True, + pr_statistics=pr_statistics, + **log_context, + ) def get_log_context(body, event, action, build_number): @@ -228,9 +305,18 @@ def get_log_context(body, event, action, build_number): git_org = body.get("organization", {}).get("login", "") installation_id = body.get("installation", {}).get("id", "") app_name = get_settings().get("CONFIG.APP_NAME", "Unknown") - log_context = {"action": action, "event": event, "sender": sender, "server_type": "github_app", - "request_id": uuid.uuid4().hex, "build_number": build_number, "app_name": app_name, - "repo": repo, "git_org": git_org, "installation_id": installation_id} + log_context = { + "action": action, + "event": event, + "sender": sender, + "server_type": "github_app", + "request_id": uuid.uuid4().hex, + "build_number": build_number, + "app_name": app_name, + "repo": repo, + "git_org": git_org, + "installation_id": installation_id, + } except Exception as e: get_logger().error("Failed to get log context", e) log_context = {} @@ -240,7 +326,10 @@ def get_log_context(body, event, action, build_number): def is_bot_user(sender, sender_type): try: # logic to ignore PRs opened by bot - if get_settings().get("GITHUB_APP.IGNORE_BOT_PR", False) and sender_type == "Bot": + if ( + get_settings().get("GITHUB_APP.IGNORE_BOT_PR", False) + and sender_type == "Bot" + ): if 'pr-agent' not in sender: get_logger().info(f"Ignoring PR from '{sender=}' because it is a bot") return True @@ -262,7 +351,9 @@ def should_process_pr_logic(body) -> bool: ignore_pr_users = get_settings().get("CONFIG.IGNORE_PR_AUTHORS", []) if ignore_pr_users and sender: 
if sender in ignore_pr_users: - get_logger().info(f"Ignoring PR from user '{sender}' due to 'config.ignore_pr_authors' setting") + get_logger().info( + f"Ignoring PR from user '{sender}' due to 'config.ignore_pr_authors' setting" + ) return False # logic to ignore PRs with specific titles @@ -270,8 +361,12 @@ def should_process_pr_logic(body) -> bool: ignore_pr_title_re = get_settings().get("CONFIG.IGNORE_PR_TITLE", []) if not isinstance(ignore_pr_title_re, list): ignore_pr_title_re = [ignore_pr_title_re] - if ignore_pr_title_re and any(re.search(regex, title) for regex in ignore_pr_title_re): - get_logger().info(f"Ignoring PR with title '{title}' due to config.ignore_pr_title setting") + if ignore_pr_title_re and any( + re.search(regex, title) for regex in ignore_pr_title_re + ): + get_logger().info( + f"Ignoring PR with title '{title}' due to config.ignore_pr_title setting" + ) return False # logic to ignore PRs with specific labels or source branches or target branches. @@ -280,20 +375,32 @@ def should_process_pr_logic(body) -> bool: labels = [label['name'] for label in pr_labels] if any(label in ignore_pr_labels for label in labels): labels_str = ", ".join(labels) - get_logger().info(f"Ignoring PR with labels '{labels_str}' due to config.ignore_pr_labels settings") + get_logger().info( + f"Ignoring PR with labels '{labels_str}' due to config.ignore_pr_labels settings" + ) return False # logic to ignore PRs with specific source or target branches - ignore_pr_source_branches = get_settings().get("CONFIG.IGNORE_PR_SOURCE_BRANCHES", []) - ignore_pr_target_branches = get_settings().get("CONFIG.IGNORE_PR_TARGET_BRANCHES", []) + ignore_pr_source_branches = get_settings().get( + "CONFIG.IGNORE_PR_SOURCE_BRANCHES", [] + ) + ignore_pr_target_branches = get_settings().get( + "CONFIG.IGNORE_PR_TARGET_BRANCHES", [] + ) if pull_request and (ignore_pr_source_branches or ignore_pr_target_branches): - if any(re.search(regex, source_branch) for regex in ignore_pr_source_branches): + if any( + re.search(regex, source_branch) for regex in ignore_pr_source_branches + ): get_logger().info( - f"Ignoring PR with source branch '{source_branch}' due to config.ignore_pr_source_branches settings") + f"Ignoring PR with source branch '{source_branch}' due to config.ignore_pr_source_branches settings" + ) return False - if any(re.search(regex, target_branch) for regex in ignore_pr_target_branches): + if any( + re.search(regex, target_branch) for regex in ignore_pr_target_branches + ): get_logger().info( - f"Ignoring PR with target branch '{target_branch}' due to config.ignore_pr_target_branches settings") + f"Ignoring PR with target branch '{target_branch}' due to config.ignore_pr_target_branches settings" + ) return False except Exception as e: get_logger().error(f"Failed 'should_process_pr_logic': {e}") @@ -308,11 +415,15 @@ async def handle_request(body: Dict[str, Any], event: str): body: The request body. event: The GitHub event type (e.g. "pull_request", "issue_comment", etc.). 
""" - action = body.get("action") # "created", "opened", "reopened", "ready_for_review", "review_requested", "synchronize" + action = body.get( + "action" + ) # "created", "opened", "reopened", "ready_for_review", "review_requested", "synchronize" if not action: return {} agent = PRAgent() - log_context, sender, sender_id, sender_type = get_log_context(body, event, action, build_number) + log_context, sender, sender_id, sender_type = get_log_context( + body, event, action, build_number + ) # logic to ignore PRs opened by bot, PRs with specific titles, labels, source branches, or target branches if is_bot_user(sender, sender_type) and 'check_run' not in body: @@ -327,21 +438,29 @@ async def handle_request(body: Dict[str, Any], event: str): # handle comments on PRs elif action == 'created': get_logger().debug(f'Request body', artifact=body, event=event) - await handle_comments_on_pr(body, event, sender, sender_id, action, log_context, agent) + await handle_comments_on_pr( + body, event, sender, sender_id, action, log_context, agent + ) # handle new PRs elif event == 'pull_request' and action != 'synchronize' and action != 'closed': get_logger().debug(f'Request body', artifact=body, event=event) - await handle_new_pr_opened(body, event, sender, sender_id, action, log_context, agent) + await handle_new_pr_opened( + body, event, sender, sender_id, action, log_context, agent + ) elif event == "issue_comment" and 'edited' in action: - pass # handle_checkbox_clicked + pass # handle_checkbox_clicked # handle pull_request event with synchronize action - "push trigger" for new commits elif event == 'pull_request' and action == 'synchronize': - await handle_push_trigger_for_new_commits(body, event, sender,sender_id, action, log_context, agent) + await handle_push_trigger_for_new_commits( + body, event, sender, sender_id, action, log_context, agent + ) elif event == 'pull_request' and action == 'closed': if get_settings().get("CONFIG.ANALYTICS_FOLDER", ""): handle_closed_pr(body, event, action, log_context) else: - get_logger().info(f"event {event=} action {action=} does not require any handling") + get_logger().info( + f"event {event=} action {action=} does not require any handling" + ) return {} @@ -362,7 +481,9 @@ def handle_line_comments(body: Dict, comment_body: [str, Any]) -> str: return comment_body -def _check_pull_request_event(action: str, body: dict, log_context: dict) -> Tuple[Dict[str, Any], str]: +def _check_pull_request_event( + action: str, body: dict, log_context: dict +) -> Tuple[Dict[str, Any], str]: invalid_result = {}, "" pull_request = body.get("pull_request") if not pull_request: @@ -373,19 +494,28 @@ def _check_pull_request_event(action: str, body: dict, log_context: dict) -> Tup log_context["api_url"] = api_url if pull_request.get("draft", True) or pull_request.get("state") != "open": return invalid_result - if action in ("review_requested", "synchronize") and pull_request.get("created_at") == pull_request.get("updated_at"): + if action in ("review_requested", "synchronize") and pull_request.get( + "created_at" + ) == pull_request.get("updated_at"): # avoid double reviews when opening a PR for the first time return invalid_result return pull_request, api_url -async def _perform_auto_commands_github(commands_conf: str, agent: PRAgent, body: dict, api_url: str, - log_context: dict): +async def _perform_auto_commands_github( + commands_conf: str, agent: PRAgent, body: dict, api_url: str, log_context: dict +): apply_repo_settings(api_url) - if commands_conf == "pr_commands" and 
get_settings().config.disable_auto_feedback: # auto commands for PR, and auto feedback is disabled - get_logger().info(f"Auto feedback is disabled, skipping auto commands for PR {api_url=}") + if ( + commands_conf == "pr_commands" and get_settings().config.disable_auto_feedback + ): # auto commands for PR, and auto feedback is disabled + get_logger().info( + f"Auto feedback is disabled, skipping auto commands for PR {api_url=}" + ) return - if not should_process_pr_logic(body): # Here we already updated the configuration with the repo settings + if not should_process_pr_logic( + body + ): # Here we already updated the configuration with the repo settings return {} commands = get_settings().get(f"github_app.{commands_conf}") if not commands: @@ -398,7 +528,9 @@ async def _perform_auto_commands_github(commands_conf: str, agent: PRAgent, body args = split_command[1:] other_args = update_settings_from_args(args) new_command = ' '.join([command] + other_args) - get_logger().info(f"{commands_conf}. Performing auto command '{new_command}', for {api_url=}") + get_logger().info( + f"{commands_conf}. Performing auto command '{new_command}', for {api_url=}" + ) await agent.handle_request(api_url, new_command) diff --git a/apps/utils/pr_agent/servers/github_polling.py b/apps/utils/pr_agent/servers/github_polling.py index 83c54f9..b06e685 100644 --- a/apps/utils/pr_agent/servers/github_polling.py +++ b/apps/utils/pr_agent/servers/github_polling.py @@ -18,11 +18,13 @@ NOTIFICATION_URL = "https://api.github.com/notifications" async def mark_notification_as_read(headers, notification, session): async with session.patch( - f"https://api.github.com/notifications/threads/{notification['id']}", - headers=headers) as mark_read_response: + f"https://api.github.com/notifications/threads/{notification['id']}", + headers=headers, + ) as mark_read_response: if mark_read_response.status != 205: get_logger().error( - f"Failed to mark notification as read. Status code: {mark_read_response.status}") + f"Failed to mark notification as read. 
Status code: {mark_read_response.status}" + ) def now() -> str: @@ -36,17 +38,21 @@ def now() -> str: now_utc = now_utc.replace("+00:00", "Z") return now_utc + async def async_handle_request(pr_url, rest_of_comment, comment_id, git_provider): agent = PRAgent() success = await agent.handle_request( pr_url, rest_of_comment, - notify=lambda: git_provider.add_eyes_reaction(comment_id) + notify=lambda: git_provider.add_eyes_reaction(comment_id), ) return success + def run_handle_request(pr_url, rest_of_comment, comment_id, git_provider): - return asyncio.run(async_handle_request(pr_url, rest_of_comment, comment_id, git_provider)) + return asyncio.run( + async_handle_request(pr_url, rest_of_comment, comment_id, git_provider) + ) def process_comment_sync(pr_url, rest_of_comment, comment_id): @@ -55,7 +61,10 @@ def process_comment_sync(pr_url, rest_of_comment, comment_id): git_provider = get_git_provider()(pr_url=pr_url) success = run_handle_request(pr_url, rest_of_comment, comment_id, git_provider) except Exception as e: - get_logger().error(f"Error processing comment: {e}", artifact={"traceback": traceback.format_exc()}) + get_logger().error( + f"Error processing comment: {e}", + artifact={"traceback": traceback.format_exc()}, + ) async def process_comment(pr_url, rest_of_comment, comment_id): @@ -66,22 +75,31 @@ async def process_comment(pr_url, rest_of_comment, comment_id): success = await agent.handle_request( pr_url, rest_of_comment, - notify=lambda: git_provider.add_eyes_reaction(comment_id) + notify=lambda: git_provider.add_eyes_reaction(comment_id), ) get_logger().info(f"Finished processing comment for PR: {pr_url}") except Exception as e: - get_logger().error(f"Error processing comment: {e}", artifact={"traceback": traceback.format_exc()}) + get_logger().error( + f"Error processing comment: {e}", + artifact={"traceback": traceback.format_exc()}, + ) + async def is_valid_notification(notification, headers, handled_ids, session, user_id): try: if 'reason' in notification and notification['reason'] == 'mention': - if 'subject' in notification and notification['subject']['type'] == 'PullRequest': + if ( + 'subject' in notification + and notification['subject']['type'] == 'PullRequest' + ): pr_url = notification['subject']['url'] latest_comment = notification['subject']['latest_comment_url'] if not latest_comment or not isinstance(latest_comment, str): get_logger().debug(f"no latest_comment") return False, handled_ids - async with session.get(latest_comment, headers=headers) as comment_response: + async with session.get( + latest_comment, headers=headers + ) as comment_response: check_prev_comments = False user_tag = "@" + user_id if comment_response.status == 200: @@ -94,7 +112,9 @@ async def is_valid_notification(notification, headers, handled_ids, session, use handled_ids.add(comment['id']) if 'user' in comment and 'login' in comment['user']: if comment['user']['login'] == user_id: - get_logger().debug(f"comment['user']['login'] == user_id") + get_logger().debug( + f"comment['user']['login'] == user_id" + ) check_prev_comments = True comment_body = comment.get('body', '') if not comment_body: @@ -105,15 +125,28 @@ async def is_valid_notification(notification, headers, handled_ids, session, use get_logger().debug(f"user_tag not in comment_body") check_prev_comments = True else: - get_logger().info(f"Polling, pr_url: {pr_url}", - artifact={"comment": comment_body}) + get_logger().info( + f"Polling, pr_url: {pr_url}", + artifact={"comment": comment_body}, + ) if not check_prev_comments: - 
return True, handled_ids, comment, comment_body, pr_url, user_tag - else: # we could not find the user tag in the latest comment. Check previous comments + return ( + True, + handled_ids, + comment, + comment_body, + pr_url, + user_tag, + ) + else: # we could not find the user tag in the latest comment. Check previous comments # get all comments in the PR - requests_url = f"{pr_url}/comments".replace("pulls", "issues") - comments_response = requests.get(requests_url, headers=headers) + requests_url = f"{pr_url}/comments".replace( + "pulls", "issues" + ) + comments_response = requests.get( + requests_url, headers=headers + ) comments = comments_response.json()[::-1] max_comment_to_scan = 4 for comment in comments[:max_comment_to_scan]: @@ -124,23 +157,37 @@ async def is_valid_notification(notification, headers, handled_ids, session, use if not comment_body: continue if user_tag in comment_body: - get_logger().info("found user tag in previous comments") - get_logger().info(f"Polling, pr_url: {pr_url}", - artifact={"comment": comment_body}) - return True, handled_ids, comment, comment_body, pr_url, user_tag + get_logger().info( + "found user tag in previous comments" + ) + get_logger().info( + f"Polling, pr_url: {pr_url}", + artifact={"comment": comment_body}, + ) + return ( + True, + handled_ids, + comment, + comment_body, + pr_url, + user_tag, + ) - get_logger().warning(f"Failed to fetch comments for PR: {pr_url}", - artifact={"comments": comments}) + get_logger().warning( + f"Failed to fetch comments for PR: {pr_url}", + artifact={"comments": comments}, + ) return False, handled_ids return False, handled_ids except Exception as e: - get_logger().exception(f"Error processing polling notification", - artifact={"notification": notification, "error": e}) + get_logger().exception( + f"Error processing polling notification", + artifact={"notification": notification, "error": e}, + ) return False, handled_ids - async def polling_loop(): """ Polls for notifications and handles them accordingly. 
@@ -171,17 +218,17 @@ async def polling_loop(): await asyncio.sleep(5) headers = { "Accept": "application/vnd.github.v3+json", - "Authorization": f"Bearer {token}" - } - params = { - "participating": "true" + "Authorization": f"Bearer {token}", } + params = {"participating": "true"} if since[0]: params["since"] = since[0] if last_modified[0]: headers["If-Modified-Since"] = last_modified[0] - async with session.get(NOTIFICATION_URL, headers=headers, params=params) as response: + async with session.get( + NOTIFICATION_URL, headers=headers, params=params + ) as response: if response.status == 200: if 'Last-Modified' in response.headers: last_modified[0] = response.headers['Last-Modified'] @@ -189,39 +236,67 @@ async def polling_loop(): notifications = await response.json() if not notifications: continue - get_logger().info(f"Received {len(notifications)} notifications") + get_logger().info( + f"Received {len(notifications)} notifications" + ) task_queue = deque() for notification in notifications: if not notification: continue # mark notification as read - await mark_notification_as_read(headers, notification, session) + await mark_notification_as_read( + headers, notification, session + ) handled_ids.add(notification['id']) - output = await is_valid_notification(notification, headers, handled_ids, session, user_id) + output = await is_valid_notification( + notification, headers, handled_ids, session, user_id + ) if output[0]: - _, handled_ids, comment, comment_body, pr_url, user_tag = output - rest_of_comment = comment_body.split(user_tag)[1].strip() + ( + _, + handled_ids, + comment, + comment_body, + pr_url, + user_tag, + ) = output + rest_of_comment = comment_body.split(user_tag)[ + 1 + ].strip() comment_id = comment['id'] # Add to the task queue get_logger().info( - f"Adding comment processing to task queue for PR, {pr_url}, comment_body: {comment_body}") - task_queue.append((process_comment_sync, (pr_url, rest_of_comment, comment_id))) - get_logger().info(f"Queued comment processing for PR: {pr_url}") + f"Adding comment processing to task queue for PR, {pr_url}, comment_body: {comment_body}" + ) + task_queue.append( + ( + process_comment_sync, + (pr_url, rest_of_comment, comment_id), + ) + ) + get_logger().info( + f"Queued comment processing for PR: {pr_url}" + ) else: - get_logger().debug(f"Skipping comment processing for PR") + get_logger().debug( + f"Skipping comment processing for PR" + ) max_allowed_parallel_tasks = 10 if task_queue: processes = [] - for i, (func, args) in enumerate(task_queue): # Create parallel tasks + for i, (func, args) in enumerate( + task_queue + ): # Create parallel tasks p = multiprocessing.Process(target=func, args=args) processes.append(p) p.start() if i > max_allowed_parallel_tasks: get_logger().error( - f"Dropping {len(task_queue) - max_allowed_parallel_tasks} tasks from polling session") + f"Dropping {len(task_queue) - max_allowed_parallel_tasks} tasks from polling session" + ) break task_queue.clear() @@ -230,11 +305,15 @@ async def polling_loop(): # p.join() elif response.status != 304: - print(f"Failed to fetch notifications. Status code: {response.status}") + print( + f"Failed to fetch notifications. 
Status code: {response.status}" + ) except Exception as e: - get_logger().error(f"Polling exception during processing of a notification: {e}", - artifact={"traceback": traceback.format_exc()}) + get_logger().error( + f"Polling exception during processing of a notification: {e}", + artifact={"traceback": traceback.format_exc()}, + ) if __name__ == '__main__': diff --git a/apps/utils/pr_agent/servers/gitlab_webhook.py b/apps/utils/pr_agent/servers/gitlab_webhook.py index d70efbc..423272d 100644 --- a/apps/utils/pr_agent/servers/gitlab_webhook.py +++ b/apps/utils/pr_agent/servers/gitlab_webhook.py @@ -22,20 +22,21 @@ from utils.pr_agent.secret_providers import get_secret_provider setup_logger(fmt=LoggingFormat.JSON, level="DEBUG") router = APIRouter() -secret_provider = get_secret_provider() if get_settings().get("CONFIG.SECRET_PROVIDER") else None +secret_provider = ( + get_secret_provider() if get_settings().get("CONFIG.SECRET_PROVIDER") else None +) async def get_mr_url_from_commit_sha(commit_sha, gitlab_token, project_id): try: import requests - headers = { - 'Private-Token': f'{gitlab_token}' - } + + headers = {'Private-Token': f'{gitlab_token}'} # API endpoint to find MRs containing the commit gitlab_url = get_settings().get("GITLAB.URL", 'https://gitlab.com') response = requests.get( f'{gitlab_url}/api/v4/projects/{project_id}/repository/commits/{commit_sha}/merge_requests', - headers=headers + headers=headers, ) merge_requests = response.json() if merge_requests and response.status_code == 200: @@ -48,6 +49,7 @@ async def get_mr_url_from_commit_sha(commit_sha, gitlab_token, project_id): get_logger().error(f"Failed to get MR url from commit sha: {e}") return None + async def handle_request(api_url: str, body: str, log_context: dict, sender_id: str): log_context["action"] = body log_context["event"] = "pull_request" if body == "/review" else "comment" @@ -58,13 +60,19 @@ async def handle_request(api_url: str, body: str, log_context: dict, sender_id: await PRAgent().handle_request(api_url, body) -async def _perform_commands_gitlab(commands_conf: str, agent: PRAgent, api_url: str, - log_context: dict, data: dict): +async def _perform_commands_gitlab( + commands_conf: str, agent: PRAgent, api_url: str, log_context: dict, data: dict +): apply_repo_settings(api_url) - if commands_conf == "pr_commands" and get_settings().config.disable_auto_feedback: # auto commands for PR, and auto feedback is disabled - get_logger().info(f"Auto feedback is disabled, skipping auto commands for PR {api_url=}", **log_context) + if ( + commands_conf == "pr_commands" and get_settings().config.disable_auto_feedback + ): # auto commands for PR, and auto feedback is disabled + get_logger().info( + f"Auto feedback is disabled, skipping auto commands for PR {api_url=}", + **log_context, + ) return - if not should_process_pr_logic(data): # Here we already updated the configurations + if not should_process_pr_logic(data): # Here we already updated the configurations return commands = get_settings().get(f"gitlab.{commands_conf}", {}) get_settings().set("config.is_auto_command", True) @@ -106,40 +114,58 @@ def should_process_pr_logic(data) -> bool: ignore_pr_users = get_settings().get("CONFIG.IGNORE_PR_AUTHORS", []) if ignore_pr_users and sender: if sender in ignore_pr_users: - get_logger().info(f"Ignoring PR from user '{sender}' due to 'config.ignore_pr_authors' settings") + get_logger().info( + f"Ignoring PR from user '{sender}' due to 'config.ignore_pr_authors' settings" + ) return False # logic to ignore MRs for 
titles, labels and source, target branches. ignore_mr_title = get_settings().get("CONFIG.IGNORE_PR_TITLE", []) ignore_mr_labels = get_settings().get("CONFIG.IGNORE_PR_LABELS", []) - ignore_mr_source_branches = get_settings().get("CONFIG.IGNORE_PR_SOURCE_BRANCHES", []) - ignore_mr_target_branches = get_settings().get("CONFIG.IGNORE_PR_TARGET_BRANCHES", []) + ignore_mr_source_branches = get_settings().get( + "CONFIG.IGNORE_PR_SOURCE_BRANCHES", [] + ) + ignore_mr_target_branches = get_settings().get( + "CONFIG.IGNORE_PR_TARGET_BRANCHES", [] + ) # if ignore_mr_source_branches: source_branch = data['object_attributes'].get('source_branch') - if any(re.search(regex, source_branch) for regex in ignore_mr_source_branches): + if any( + re.search(regex, source_branch) for regex in ignore_mr_source_branches + ): get_logger().info( - f"Ignoring MR with source branch '{source_branch}' due to gitlab.ignore_mr_source_branches settings") + f"Ignoring MR with source branch '{source_branch}' due to gitlab.ignore_mr_source_branches settings" + ) return False if ignore_mr_target_branches: target_branch = data['object_attributes'].get('target_branch') - if any(re.search(regex, target_branch) for regex in ignore_mr_target_branches): + if any( + re.search(regex, target_branch) for regex in ignore_mr_target_branches + ): get_logger().info( - f"Ignoring MR with target branch '{target_branch}' due to gitlab.ignore_mr_target_branches settings") + f"Ignoring MR with target branch '{target_branch}' due to gitlab.ignore_mr_target_branches settings" + ) return False if ignore_mr_labels: - labels = [label['title'] for label in data['object_attributes'].get('labels', [])] + labels = [ + label['title'] for label in data['object_attributes'].get('labels', []) + ] if any(label in ignore_mr_labels for label in labels): labels_str = ", ".join(labels) - get_logger().info(f"Ignoring MR with labels '{labels_str}' due to gitlab.ignore_mr_labels settings") + get_logger().info( + f"Ignoring MR with labels '{labels_str}' due to gitlab.ignore_mr_labels settings" + ) return False if ignore_mr_title: if any(re.search(regex, title) for regex in ignore_mr_title): - get_logger().info(f"Ignoring MR with title '{title}' due to gitlab.ignore_mr_title settings") + get_logger().info( + f"Ignoring MR with title '{title}' due to gitlab.ignore_mr_title settings" + ) return False except Exception as e: get_logger().error(f"Failed 'should_process_pr_logic': {e}") @@ -159,29 +185,47 @@ async def gitlab_webhook(background_tasks: BackgroundTasks, request: Request): request_token = request.headers.get("X-Gitlab-Token") secret = secret_provider.get_secret(request_token) if not secret: - get_logger().warning(f"Empty secret retrieved, request_token: {request_token}") - return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, - content=jsonable_encoder({"message": "unauthorized"})) + get_logger().warning( + f"Empty secret retrieved, request_token: {request_token}" + ) + return JSONResponse( + status_code=status.HTTP_401_UNAUTHORIZED, + content=jsonable_encoder({"message": "unauthorized"}), + ) try: secret_dict = json.loads(secret) gitlab_token = secret_dict["gitlab_token"] - log_context["token_id"] = secret_dict.get("token_name", secret_dict.get("id", "unknown")) + log_context["token_id"] = secret_dict.get( + "token_name", secret_dict.get("id", "unknown") + ) context["settings"].gitlab.personal_access_token = gitlab_token except Exception as e: get_logger().error(f"Failed to validate secret {request_token}: {e}") - return 
JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content=jsonable_encoder({"message": "unauthorized"})) + return JSONResponse( + status_code=status.HTTP_401_UNAUTHORIZED, + content=jsonable_encoder({"message": "unauthorized"}), + ) elif get_settings().get("GITLAB.SHARED_SECRET"): secret = get_settings().get("GITLAB.SHARED_SECRET") if not request.headers.get("X-Gitlab-Token") == secret: get_logger().error("Failed to validate secret") - return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content=jsonable_encoder({"message": "unauthorized"})) + return JSONResponse( + status_code=status.HTTP_401_UNAUTHORIZED, + content=jsonable_encoder({"message": "unauthorized"}), + ) else: get_logger().error("Failed to validate secret") - return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content=jsonable_encoder({"message": "unauthorized"})) + return JSONResponse( + status_code=status.HTTP_401_UNAUTHORIZED, + content=jsonable_encoder({"message": "unauthorized"}), + ) gitlab_token = get_settings().get("GITLAB.PERSONAL_ACCESS_TOKEN", None) if not gitlab_token: get_logger().error("No gitlab token found") - return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content=jsonable_encoder({"message": "unauthorized"})) + return JSONResponse( + status_code=status.HTTP_401_UNAUTHORIZED, + content=jsonable_encoder({"message": "unauthorized"}), + ) get_logger().info("GitLab data", artifact=data) sender = data.get("user", {}).get("username", "unknown") @@ -189,31 +233,49 @@ async def gitlab_webhook(background_tasks: BackgroundTasks, request: Request): # ignore bot users if is_bot_user(data): - return JSONResponse(status_code=status.HTTP_200_OK, content=jsonable_encoder({"message": "success"})) - if data.get('event_type') != 'note': # not a comment + return JSONResponse( + status_code=status.HTTP_200_OK, + content=jsonable_encoder({"message": "success"}), + ) + if data.get('event_type') != 'note': # not a comment # ignore MRs based on title, labels, source and target branches if not should_process_pr_logic(data): - return JSONResponse(status_code=status.HTTP_200_OK, content=jsonable_encoder({"message": "success"})) + return JSONResponse( + status_code=status.HTTP_200_OK, + content=jsonable_encoder({"message": "success"}), + ) log_context["sender"] = sender - if data.get('object_kind') == 'merge_request' and data['object_attributes'].get('action') in ['open', 'reopen']: + if data.get('object_kind') == 'merge_request' and data['object_attributes'].get( + 'action' + ) in ['open', 'reopen']: title = data['object_attributes'].get('title') url = data['object_attributes'].get('url') draft = data['object_attributes'].get('draft') get_logger().info(f"New merge request: {url}") if draft: get_logger().info(f"Skipping draft MR: {url}") - return JSONResponse(status_code=status.HTTP_200_OK, content=jsonable_encoder({"message": "success"})) + return JSONResponse( + status_code=status.HTTP_200_OK, + content=jsonable_encoder({"message": "success"}), + ) - await _perform_commands_gitlab("pr_commands", PRAgent(), url, log_context, data) - elif data.get('object_kind') == 'note' and data.get('event_type') == 'note': # comment on MR + await _perform_commands_gitlab( + "pr_commands", PRAgent(), url, log_context, data + ) + elif ( + data.get('object_kind') == 'note' and data.get('event_type') == 'note' + ): # comment on MR if 'merge_request' in data: mr = data['merge_request'] url = mr.get('url') get_logger().info(f"A comment has been added to a merge request: {url}") body = data.get('object_attributes', 
{}).get('note') - if data.get('object_attributes', {}).get('type') == 'DiffNote' and '/ask' in body: # /ask_line + if ( + data.get('object_attributes', {}).get('type') == 'DiffNote' + and '/ask' in body + ): # /ask_line body = handle_ask_line(body, data) await handle_request(url, body, log_context, sender_id) @@ -221,30 +283,44 @@ async def gitlab_webhook(background_tasks: BackgroundTasks, request: Request): try: project_id = data['project_id'] commit_sha = data['checkout_sha'] - url = await get_mr_url_from_commit_sha(commit_sha, gitlab_token, project_id) + url = await get_mr_url_from_commit_sha( + commit_sha, gitlab_token, project_id + ) if not url: get_logger().info(f"No MR found for commit: {commit_sha}") - return JSONResponse(status_code=status.HTTP_200_OK, - content=jsonable_encoder({"message": "success"})) + return JSONResponse( + status_code=status.HTTP_200_OK, + content=jsonable_encoder({"message": "success"}), + ) # we need first to apply_repo_settings apply_repo_settings(url) commands_on_push = get_settings().get(f"gitlab.push_commands", {}) - handle_push_trigger = get_settings().get(f"gitlab.handle_push_trigger", False) + handle_push_trigger = get_settings().get( + f"gitlab.handle_push_trigger", False + ) if not commands_on_push or not handle_push_trigger: - get_logger().info("Push event, but no push commands found or push trigger is disabled") - return JSONResponse(status_code=status.HTTP_200_OK, - content=jsonable_encoder({"message": "success"})) + get_logger().info( + "Push event, but no push commands found or push trigger is disabled" + ) + return JSONResponse( + status_code=status.HTTP_200_OK, + content=jsonable_encoder({"message": "success"}), + ) get_logger().debug(f'A push event has been received: {url}') - await _perform_commands_gitlab("push_commands", PRAgent(), url, log_context, data) + await _perform_commands_gitlab( + "push_commands", PRAgent(), url, log_context, data + ) except Exception as e: get_logger().error(f"Failed to handle push event: {e}") background_tasks.add_task(inner, request_json) end_time = datetime.now() get_logger().info(f"Processing time: {end_time - start_time}", request=request_json) - return JSONResponse(status_code=status.HTTP_200_OK, content=jsonable_encoder({"message": "success"})) + return JSONResponse( + status_code=status.HTTP_200_OK, content=jsonable_encoder({"message": "success"}) + ) def handle_ask_line(body, data): @@ -271,6 +347,7 @@ def handle_ask_line(body, data): async def root(): return {"status": "ok"} + gitlab_url = get_settings().get("GITLAB.URL", None) if not gitlab_url: raise ValueError("GITLAB.URL is not set") diff --git a/apps/utils/pr_agent/servers/help.py b/apps/utils/pr_agent/servers/help.py index 7edd13d..0af25dd 100644 --- a/apps/utils/pr_agent/servers/help.py +++ b/apps/utils/pr_agent/servers/help.py @@ -1,18 +1,19 @@ class HelpMessage: @staticmethod def get_general_commands_text(): - commands_text = "> - **/review**: Request a review of your Pull Request. \n" \ - "> - **/describe**: Update the PR title and description based on the contents of the PR. \n" \ - "> - **/improve [--extended]**: Suggest code improvements. Extended mode provides a higher quality feedback. \n" \ - "> - **/ask \\**: Ask a question about the PR. \n" \ - "> - **/update_changelog**: Update the changelog based on the PR's contents. \n" \ - "> - **/add_docs** 💎: Generate docstring for new components introduced in the PR. \n" \ - "> - **/generate_labels** 💎: Generate labels for the PR based on the PR's contents. 
\n" \ - "> - **/analyze** 💎: Automatically analyzes the PR, and presents changes walkthrough for each component. \n\n" \ - ">See the [tools guide](https://pr-agent-docs.codium.ai/tools/) for more details.\n" \ - ">To list the possible configuration parameters, add a **/config** comment. \n" - return commands_text - + commands_text = ( + "> - **/review**: Request a review of your Pull Request. \n" + "> - **/describe**: Update the PR title and description based on the contents of the PR. \n" + "> - **/improve [--extended]**: Suggest code improvements. Extended mode provides a higher quality feedback. \n" + "> - **/ask \\**: Ask a question about the PR. \n" + "> - **/update_changelog**: Update the changelog based on the PR's contents. \n" + "> - **/add_docs** 💎: Generate docstring for new components introduced in the PR. \n" + "> - **/generate_labels** 💎: Generate labels for the PR based on the PR's contents. \n" + "> - **/analyze** 💎: Automatically analyzes the PR, and presents changes walkthrough for each component. \n\n" + ">See the [tools guide](https://pr-agent-docs.codium.ai/tools/) for more details.\n" + ">To list the possible configuration parameters, add a **/config** comment. \n" + ) + return commands_text @staticmethod def get_general_bot_help_text(): @@ -21,10 +22,12 @@ class HelpMessage: @staticmethod def get_review_usage_guide(): - output ="**Overview:**\n" - output +=("The `review` tool scans the PR code changes, and generates a PR review which includes several types of feedbacks, such as possible PR issues, security threats and relevant test in the PR. More feedbacks can be [added](https://pr-agent-docs.codium.ai/tools/review/#general-configurations) by configuring the tool.\n\n" - "The tool can be triggered [automatically](https://pr-agent-docs.codium.ai/usage-guide/automations_and_usage/#github-app-automatic-tools-when-a-new-pr-is-opened) every time a new PR is opened, or can be invoked manually by commenting on any PR.\n") - output +="""\ + output = "**Overview:**\n" + output += ( + "The `review` tool scans the PR code changes, and generates a PR review which includes several types of feedbacks, such as possible PR issues, security threats and relevant test in the PR. More feedbacks can be [added](https://pr-agent-docs.codium.ai/tools/review/#general-configurations) by configuring the tool.\n\n" + "The tool can be triggered [automatically](https://pr-agent-docs.codium.ai/usage-guide/automations_and_usage/#github-app-automatic-tools-when-a-new-pr-is-opened) every time a new PR is opened, or can be invoked manually by commenting on any PR.\n" + ) + output += """\ - When commenting, to edit [configurations](https://github.com/Codium-ai/pr-agent/blob/main/pr_agent/settings/configuration.toml#L23) related to the review tool (`pr_reviewer` section), use the following template: ``` /review --pr_reviewer.some_config1=... --pr_reviewer.some_config2=... @@ -41,8 +44,6 @@ some_config2=... return output - - @staticmethod def get_describe_usage_guide(): output = "**Overview:**\n" @@ -137,7 +138,6 @@ Use triple quotes to write multi-line instructions. Use bullet points to make th ''' output += "\n\n
\n\n" - # general output += "\n\n
More PR-Agent commands
\n\n" output += HelpMessage.get_general_bot_help_text() @@ -175,7 +175,6 @@ You can ask questions about the entire PR, about specific code lines, or about a return output - @staticmethod def get_improve_usage_guide(): output = "**Overview:**\n" diff --git a/apps/utils/pr_agent/servers/utils.py b/apps/utils/pr_agent/servers/utils.py index 4b1ea80..4b3c788 100644 --- a/apps/utils/pr_agent/servers/utils.py +++ b/apps/utils/pr_agent/servers/utils.py @@ -18,8 +18,12 @@ def verify_signature(payload_body, secret_token, signature_header): signature_header: header received from GitHub (x-hub-signature-256) """ if not signature_header: - raise HTTPException(status_code=403, detail="x-hub-signature-256 header is missing!") - hash_object = hmac.new(secret_token.encode('utf-8'), msg=payload_body, digestmod=hashlib.sha256) + raise HTTPException( + status_code=403, detail="x-hub-signature-256 header is missing!" + ) + hash_object = hmac.new( + secret_token.encode('utf-8'), msg=payload_body, digestmod=hashlib.sha256 + ) expected_signature = "sha256=" + hash_object.hexdigest() if not hmac.compare_digest(expected_signature, signature_header): raise HTTPException(status_code=403, detail="Request signatures didn't match!") @@ -27,6 +31,7 @@ def verify_signature(payload_body, secret_token, signature_header): class RateLimitExceeded(Exception): """Raised when the git provider API rate limit has been exceeded.""" + pass @@ -66,7 +71,11 @@ class DefaultDictWithTimeout(defaultdict): request_time = self.__time() if request_time - self.__last_refresh > self.__refresh_interval: return - to_delete = [key for key, key_time in self.__key_times.items() if request_time - key_time > self.__ttl] + to_delete = [ + key + for key, key_time in self.__key_times.items() + if request_time - key_time > self.__ttl + ] for key in to_delete: del self[key] self.__last_refresh = request_time diff --git a/apps/utils/pr_agent/tools/pr_add_docs.py b/apps/utils/pr_agent/tools/pr_add_docs.py index 362e8b5..be59838 100644 --- a/apps/utils/pr_agent/tools/pr_add_docs.py +++ b/apps/utils/pr_agent/tools/pr_add_docs.py @@ -17,9 +17,13 @@ from utils.pr_agent.log import get_logger class PRAddDocs: - def __init__(self, pr_url: str, cli_mode=False, args: list = None, - ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler): - + def __init__( + self, + pr_url: str, + cli_mode=False, + args: list = None, + ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler, + ): self.git_provider = get_git_provider()(pr_url) self.main_language = get_main_pr_language( self.git_provider.get_languages(), self.git_provider.get_files() @@ -39,13 +43,16 @@ class PRAddDocs: "diff": "", # empty diff for initial calculation "extra_instructions": get_settings().pr_add_docs.extra_instructions, "commit_messages_str": self.git_provider.get_commit_messages(), - 'docs_for_language': get_docs_for_language(self.main_language, - get_settings().pr_add_docs.docs_style), + 'docs_for_language': get_docs_for_language( + self.main_language, get_settings().pr_add_docs.docs_style + ), } - self.token_handler = TokenHandler(self.git_provider.pr, - self.vars, - get_settings().pr_add_docs_prompt.system, - get_settings().pr_add_docs_prompt.user) + self.token_handler = TokenHandler( + self.git_provider.pr, + self.vars, + get_settings().pr_add_docs_prompt.system, + get_settings().pr_add_docs_prompt.user, + ) async def run(self): try: @@ -66,16 +73,20 @@ class PRAddDocs: get_logger().info('Pushing inline code documentation...') self.push_inline_docs(data) except Exception as e: - 
get_logger().error(f"Failed to generate code documentation for PR, error: {e}") + get_logger().error( + f"Failed to generate code documentation for PR, error: {e}" + ) async def _prepare_prediction(self, model: str): get_logger().info('Getting PR diff...') - self.patches_diff = get_pr_diff(self.git_provider, - self.token_handler, - model, - add_line_numbers_to_hunks=True, - disable_extra_lines=False) + self.patches_diff = get_pr_diff( + self.git_provider, + self.token_handler, + model, + add_line_numbers_to_hunks=True, + disable_extra_lines=False, + ) get_logger().info('Getting AI prediction...') self.prediction = await self._get_prediction(model) @@ -84,13 +95,21 @@ class PRAddDocs: variables = copy.deepcopy(self.vars) variables["diff"] = self.patches_diff # update diff environment = Environment(undefined=StrictUndefined) - system_prompt = environment.from_string(get_settings().pr_add_docs_prompt.system).render(variables) - user_prompt = environment.from_string(get_settings().pr_add_docs_prompt.user).render(variables) + system_prompt = environment.from_string( + get_settings().pr_add_docs_prompt.system + ).render(variables) + user_prompt = environment.from_string( + get_settings().pr_add_docs_prompt.user + ).render(variables) if get_settings().config.verbosity_level >= 2: get_logger().info(f"\nSystem prompt:\n{system_prompt}") get_logger().info(f"\nUser prompt:\n{user_prompt}") response, finish_reason = await self.ai_handler.chat_completion( - model=model, temperature=get_settings().config.temperature, system=system_prompt, user=user_prompt) + model=model, + temperature=get_settings().config.temperature, + system=system_prompt, + user=user_prompt, + ) return response @@ -105,7 +124,9 @@ class PRAddDocs: docs = [] if not data['Code Documentation']: - return self.git_provider.publish_comment('No code documentation found to improve this PR.') + return self.git_provider.publish_comment( + 'No code documentation found to improve this PR.' 
+ ) for d in data['Code Documentation']: try: @@ -116,32 +137,59 @@ class PRAddDocs: documentation = d['documentation'] doc_placement = d['doc placement'].strip() if documentation: - new_code_snippet = self.dedent_code(relevant_file, relevant_line, documentation, doc_placement, - add_original_line=True) + new_code_snippet = self.dedent_code( + relevant_file, + relevant_line, + documentation, + doc_placement, + add_original_line=True, + ) - body = f"**Suggestion:** Proposed documentation\n```suggestion\n" + new_code_snippet + "\n```" - docs.append({'body': body, 'relevant_file': relevant_file, - 'relevant_lines_start': relevant_line, - 'relevant_lines_end': relevant_line}) + body = ( + f"**Suggestion:** Proposed documentation\n```suggestion\n" + + new_code_snippet + + "\n```" + ) + docs.append( + { + 'body': body, + 'relevant_file': relevant_file, + 'relevant_lines_start': relevant_line, + 'relevant_lines_end': relevant_line, + } + ) except Exception: if get_settings().config.verbosity_level >= 2: get_logger().info(f"Could not parse code docs: {d}") is_successful = self.git_provider.publish_code_suggestions(docs) if not is_successful: - get_logger().info("Failed to publish code docs, trying to publish each docs separately") + get_logger().info( + "Failed to publish code docs, trying to publish each docs separately" + ) for doc_suggestion in docs: self.git_provider.publish_code_suggestions([doc_suggestion]) - def dedent_code(self, relevant_file, relevant_lines_start, new_code_snippet, doc_placement='after', - add_original_line=False): + def dedent_code( + self, + relevant_file, + relevant_lines_start, + new_code_snippet, + doc_placement='after', + add_original_line=False, + ): try: # dedent code snippet - self.diff_files = self.git_provider.diff_files if self.git_provider.diff_files \ + self.diff_files = ( + self.git_provider.diff_files + if self.git_provider.diff_files else self.git_provider.get_diff_files() + ) original_initial_line = None for file in self.diff_files: if file.filename.strip() == relevant_file: - original_initial_line = file.head_file.splitlines()[relevant_lines_start - 1] + original_initial_line = file.head_file.splitlines()[ + relevant_lines_start - 1 + ] break if original_initial_line: if doc_placement == 'after': @@ -150,18 +198,28 @@ class PRAddDocs: line = original_initial_line suggested_initial_line = new_code_snippet.splitlines()[0] original_initial_spaces = len(line) - len(line.lstrip()) - suggested_initial_spaces = len(suggested_initial_line) - len(suggested_initial_line.lstrip()) + suggested_initial_spaces = len(suggested_initial_line) - len( + suggested_initial_line.lstrip() + ) delta_spaces = original_initial_spaces - suggested_initial_spaces if delta_spaces > 0: - new_code_snippet = textwrap.indent(new_code_snippet, delta_spaces * " ").rstrip('\n') + new_code_snippet = textwrap.indent( + new_code_snippet, delta_spaces * " " + ).rstrip('\n') if add_original_line: if doc_placement == 'after': - new_code_snippet = original_initial_line + "\n" + new_code_snippet + new_code_snippet = ( + original_initial_line + "\n" + new_code_snippet + ) else: - new_code_snippet = new_code_snippet.rstrip() + "\n" + original_initial_line + new_code_snippet = ( + new_code_snippet.rstrip() + "\n" + original_initial_line + ) except Exception as e: if get_settings().config.verbosity_level >= 2: - get_logger().info(f"Could not dedent code snippet for file {relevant_file}, error: {e}") + get_logger().info( + f"Could not dedent code snippet for file {relevant_file}, error: {e}" + ) return 
new_code_snippet diff --git a/apps/utils/pr_agent/tools/pr_code_suggestions.py b/apps/utils/pr_agent/tools/pr_code_suggestions.py index b0dd5c7..9fce273 100644 --- a/apps/utils/pr_agent/tools/pr_code_suggestions.py +++ b/apps/utils/pr_agent/tools/pr_code_suggestions.py @@ -12,15 +12,25 @@ from jinja2 import Environment, StrictUndefined from utils.pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler from utils.pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler -from utils.pr_agent.algo.pr_processing import (add_ai_metadata_to_diff_files, - get_pr_diff, get_pr_multi_diffs, - retry_with_fallback_models) +from utils.pr_agent.algo.pr_processing import ( + add_ai_metadata_to_diff_files, + get_pr_diff, + get_pr_multi_diffs, + retry_with_fallback_models, +) from utils.pr_agent.algo.token_handler import TokenHandler -from utils.pr_agent.algo.utils import (ModelType, load_yaml, replace_code_tags, - show_relevant_configurations) +from utils.pr_agent.algo.utils import ( + ModelType, + load_yaml, + replace_code_tags, + show_relevant_configurations, +) from utils.pr_agent.config_loader import get_settings -from utils.pr_agent.git_providers import (AzureDevopsProvider, GithubProvider, - get_git_provider_with_context) +from utils.pr_agent.git_providers import ( + AzureDevopsProvider, + GithubProvider, + get_git_provider_with_context, +) from utils.pr_agent.git_providers.git_provider import get_main_pr_language, GitProvider from utils.pr_agent.log import get_logger from utils.pr_agent.servers.help import HelpMessage @@ -28,9 +38,13 @@ from utils.pr_agent.tools.pr_description import insert_br_after_x_chars class PRCodeSuggestions: - def __init__(self, pr_url: str, cli_mode=False, args: list = None, - ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler): - + def __init__( + self, + pr_url: str, + cli_mode=False, + args: list = None, + ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler, + ): self.git_provider = get_git_provider_with_context(pr_url) self.main_language = get_main_pr_language( self.git_provider.get_languages(), self.git_provider.get_files() @@ -38,10 +52,16 @@ class PRCodeSuggestions: # limit context specifically for the improve command, which has hard input to parse: if get_settings().pr_code_suggestions.max_context_tokens: - MAX_CONTEXT_TOKENS_IMPROVE = get_settings().pr_code_suggestions.max_context_tokens + MAX_CONTEXT_TOKENS_IMPROVE = ( + get_settings().pr_code_suggestions.max_context_tokens + ) if get_settings().config.max_model_tokens > MAX_CONTEXT_TOKENS_IMPROVE: - get_logger().info(f"Setting max_model_tokens to {MAX_CONTEXT_TOKENS_IMPROVE} for PR improve") - get_settings().config.max_model_tokens_original = get_settings().config.max_model_tokens + get_logger().info( + f"Setting max_model_tokens to {MAX_CONTEXT_TOKENS_IMPROVE} for PR improve" + ) + get_settings().config.max_model_tokens_original = ( + get_settings().config.max_model_tokens + ) get_settings().config.max_model_tokens = MAX_CONTEXT_TOKENS_IMPROVE # extended mode @@ -49,8 +69,9 @@ class PRCodeSuggestions: self.is_extended = self._get_is_extended(args or []) except: self.is_extended = False - num_code_suggestions = int(get_settings().pr_code_suggestions.num_code_suggestions_per_chunk) - + num_code_suggestions = int( + get_settings().pr_code_suggestions.num_code_suggestions_per_chunk + ) self.ai_handler = ai_handler() self.ai_handler.main_pr_language = self.main_language @@ -58,10 +79,15 @@ class PRCodeSuggestions: self.prediction = None self.pr_url = pr_url self.cli_mode = cli_mode - 
self.pr_description, self.pr_description_files = ( - self.git_provider.get_pr_description(split_changes_walkthrough=True)) - if (self.pr_description_files and get_settings().get("config.is_auto_command", False) and - get_settings().get("config.enable_ai_metadata", False)): + ( + self.pr_description, + self.pr_description_files, + ) = self.git_provider.get_pr_description(split_changes_walkthrough=True) + if ( + self.pr_description_files + and get_settings().get("config.is_auto_command", False) + and get_settings().get("config.enable_ai_metadata", False) + ): add_ai_metadata_to_diff_files(self.git_provider, self.pr_description_files) get_logger().debug(f"AI metadata added to the this command") else: @@ -80,16 +106,24 @@ class PRCodeSuggestions: "commit_messages_str": self.git_provider.get_commit_messages(), "relevant_best_practices": "", "is_ai_metadata": get_settings().get("config.enable_ai_metadata", False), - "focus_only_on_problems": get_settings().get("pr_code_suggestions.focus_only_on_problems", False), + "focus_only_on_problems": get_settings().get( + "pr_code_suggestions.focus_only_on_problems", False + ), "date": datetime.now().strftime('%Y-%m-%d'), - 'duplicate_prompt_examples': get_settings().config.get('duplicate_prompt_examples', False), + 'duplicate_prompt_examples': get_settings().config.get( + 'duplicate_prompt_examples', False + ), } - self.pr_code_suggestions_prompt_system = get_settings().pr_code_suggestions_prompt.system + self.pr_code_suggestions_prompt_system = ( + get_settings().pr_code_suggestions_prompt.system + ) - self.token_handler = TokenHandler(self.git_provider.pr, - self.vars, - self.pr_code_suggestions_prompt_system, - get_settings().pr_code_suggestions_prompt.user) + self.token_handler = TokenHandler( + self.git_provider.pr, + self.vars, + self.pr_code_suggestions_prompt_system, + get_settings().pr_code_suggestions_prompt.user, + ) self.progress = f"## 生成 PR 代码建议\n\n" self.progress += f"""\n思考中 ...
\n""" @@ -98,33 +132,50 @@ class PRCodeSuggestions: async def run(self): try: if not self.git_provider.get_files(): - get_logger().info(f"PR has no files: {self.pr_url}, skipping code suggestions") + get_logger().info( + f"PR has no files: {self.pr_url}, skipping code suggestions" + ) return None get_logger().info('Generating code suggestions for PR...') - relevant_configs = {'pr_code_suggestions': dict(get_settings().pr_code_suggestions), - 'config': dict(get_settings().config)} + relevant_configs = { + 'pr_code_suggestions': dict(get_settings().pr_code_suggestions), + 'config': dict(get_settings().config), + } get_logger().debug("Relevant configs", artifacts=relevant_configs) # publish "Preparing suggestions..." comments - if (get_settings().config.publish_output and get_settings().config.publish_output_progress and - not get_settings().config.get('is_auto_command', False)): + if ( + get_settings().config.publish_output + and get_settings().config.publish_output_progress + and not get_settings().config.get('is_auto_command', False) + ): if self.git_provider.is_supported("gfm_markdown"): - self.progress_response = self.git_provider.publish_comment(self.progress) + self.progress_response = self.git_provider.publish_comment( + self.progress + ) else: self.git_provider.publish_comment("准备建议中...", is_temporary=True) # call the model to get the suggestions, and self-reflect on them if not self.is_extended: - data = await retry_with_fallback_models(self._prepare_prediction, model_type=ModelType.REGULAR) + data = await retry_with_fallback_models( + self._prepare_prediction, model_type=ModelType.REGULAR + ) else: - data = await retry_with_fallback_models(self._prepare_prediction_extended, model_type=ModelType.REGULAR) + data = await retry_with_fallback_models( + self._prepare_prediction_extended, model_type=ModelType.REGULAR + ) if not data: data = {"code_suggestions": []} self.data = data # Handle the case where the PR has no suggestions - if (data is None or 'code_suggestions' not in data or not data['code_suggestions']): + if ( + data is None + or 'code_suggestions' not in data + or not data['code_suggestions'] + ): await self.publish_no_suggestions() return @@ -134,20 +185,25 @@ class PRCodeSuggestions: self.git_provider.remove_initial_comment() # Publish table summarized suggestions - if ((not get_settings().pr_code_suggestions.commitable_code_suggestions) and - self.git_provider.is_supported("gfm_markdown")): - + if ( + not get_settings().pr_code_suggestions.commitable_code_suggestions + ) and self.git_provider.is_supported("gfm_markdown"): # generate summarized suggestions pr_body = self.generate_summarized_suggestions(data) get_logger().debug(f"PR output", artifact=pr_body) # require self-review - if get_settings().pr_code_suggestions.demand_code_suggestions_self_review: + if ( + get_settings().pr_code_suggestions.demand_code_suggestions_self_review + ): pr_body = await self.add_self_review_text(pr_body) # add usage guide - if (get_settings().pr_code_suggestions.enable_chat_text and get_settings().config.is_auto_command - and isinstance(self.git_provider, GithubProvider)): + if ( + get_settings().pr_code_suggestions.enable_chat_text + and get_settings().config.is_auto_command + and isinstance(self.git_provider, GithubProvider) + ): pr_body += "\n\n>💡 Need additional feedback ? start a [PR chat](https://chromewebstore.google.com/detail/ephlnjeghhogofkifjloamocljapahnl) \n\n" if get_settings().pr_code_suggestions.enable_help_text: pr_body += "
<hr>\n\n<details> <summary><strong>💡 Tool usage guide:</strong></summary><hr>
\n\n" @@ -155,55 +211,84 @@ class PRCodeSuggestions: pr_body += "\n
\n" # Output the relevant configurations if enabled - if get_settings().get('config', {}).get('output_relevant_configurations', False): - pr_body += show_relevant_configurations(relevant_section='pr_code_suggestions') + if ( + get_settings() + .get('config', {}) + .get('output_relevant_configurations', False) + ): + pr_body += show_relevant_configurations( + relevant_section='pr_code_suggestions' + ) # publish the PR comment - if get_settings().pr_code_suggestions.persistent_comment: # true by default - self.publish_persistent_comment_with_history(self.git_provider, - pr_body, - initial_header="## PR 代码建议 ✨", - update_header=True, - name="suggestions", - final_update_message=False, - max_previous_comments=get_settings().pr_code_suggestions.max_history_len, - progress_response=self.progress_response) + if ( + get_settings().pr_code_suggestions.persistent_comment + ): # true by default + self.publish_persistent_comment_with_history( + self.git_provider, + pr_body, + initial_header="## PR 代码建议 ✨", + update_header=True, + name="suggestions", + final_update_message=False, + max_previous_comments=get_settings().pr_code_suggestions.max_history_len, + progress_response=self.progress_response, + ) else: if self.progress_response: - self.git_provider.edit_comment(self.progress_response, body=pr_body) + self.git_provider.edit_comment( + self.progress_response, body=pr_body + ) else: self.git_provider.publish_comment(pr_body) # dual publishing mode - if int(get_settings().pr_code_suggestions.dual_publishing_score_threshold) > 0: + if ( + int( + get_settings().pr_code_suggestions.dual_publishing_score_threshold + ) + > 0 + ): await self.dual_publishing(data) else: await self.push_inline_code_suggestions(data) if self.progress_response: self.git_provider.remove_comment(self.progress_response) else: - get_logger().info('Code suggestions generated for PR, but not published since publish_output is False.') + get_logger().info( + 'Code suggestions generated for PR, but not published since publish_output is False.' 
+ ) pr_body = self.generate_summarized_suggestions(data) get_settings().data = {"artifact": pr_body} return except Exception as e: - get_logger().error(f"Failed to generate code suggestions for PR, error: {e}", - artifact={"traceback": traceback.format_exc()}) + get_logger().error( + f"Failed to generate code suggestions for PR, error: {e}", + artifact={"traceback": traceback.format_exc()}, + ) if get_settings().config.publish_output: if self.progress_response: self.progress_response.delete() else: try: self.git_provider.remove_initial_comment() - self.git_provider.publish_comment(f"Failed to generate code suggestions for PR") + self.git_provider.publish_comment( + f"Failed to generate code suggestions for PR" + ) except Exception as e: - get_logger().exception(f"Failed to update persistent review, error: {e}") + get_logger().exception( + f"Failed to update persistent review, error: {e}" + ) async def add_self_review_text(self, pr_body): text = get_settings().pr_code_suggestions.code_suggestions_self_review_text pr_body += f"\n\n- [ ] {text}" - approve_pr_on_self_review = get_settings().pr_code_suggestions.approve_pr_on_self_review - fold_suggestions_on_self_review = get_settings().pr_code_suggestions.fold_suggestions_on_self_review + approve_pr_on_self_review = ( + get_settings().pr_code_suggestions.approve_pr_on_self_review + ) + fold_suggestions_on_self_review = ( + get_settings().pr_code_suggestions.fold_suggestions_on_self_review + ) if approve_pr_on_self_review and not fold_suggestions_on_self_review: pr_body += ' ' elif fold_suggestions_on_self_review and not approve_pr_on_self_review: @@ -214,7 +299,10 @@ class PRCodeSuggestions: async def publish_no_suggestions(self): pr_body = "## PR 代码建议 ✨\n\n未找到该PR的代码建议." - if get_settings().config.publish_output and get_settings().config.publish_output_no_suggestions: + if ( + get_settings().config.publish_output + and get_settings().config.publish_output_no_suggestions + ): get_logger().warning('No code suggestions found for the PR.') get_logger().debug(f"PR output", artifact=pr_body) if self.progress_response: @@ -229,31 +317,40 @@ class PRCodeSuggestions: try: for suggestion in data['code_suggestions']: if int(suggestion.get('score', 0)) >= int( - get_settings().pr_code_suggestions.dual_publishing_score_threshold) \ - and suggestion.get('improved_code'): + get_settings().pr_code_suggestions.dual_publishing_score_threshold + ) and suggestion.get('improved_code'): data_above_threshold['code_suggestions'].append(suggestion) - if not data_above_threshold['code_suggestions'][-1]['existing_code']: - get_logger().info(f'Identical existing and improved code for dual publishing found') - data_above_threshold['code_suggestions'][-1]['existing_code'] = suggestion[ - 'improved_code'] + if not data_above_threshold['code_suggestions'][-1][ + 'existing_code' + ]: + get_logger().info( + f'Identical existing and improved code for dual publishing found' + ) + data_above_threshold['code_suggestions'][-1][ + 'existing_code' + ] = suggestion['improved_code'] if data_above_threshold['code_suggestions']: get_logger().info( - f"Publishing {len(data_above_threshold['code_suggestions'])} suggestions in dual publishing mode") + f"Publishing {len(data_above_threshold['code_suggestions'])} suggestions in dual publishing mode" + ) await self.push_inline_code_suggestions(data_above_threshold) except Exception as e: - get_logger().error(f"Failed to publish dual publishing suggestions, error: {e}") + get_logger().error( + f"Failed to publish dual publishing suggestions, 
error: {e}" + ) @staticmethod - def publish_persistent_comment_with_history(git_provider: GitProvider, - pr_comment: str, - initial_header: str, - update_header: bool = True, - name='review', - final_update_message=True, - max_previous_comments=4, - progress_response=None, - only_fold=False): - + def publish_persistent_comment_with_history( + git_provider: GitProvider, + pr_comment: str, + initial_header: str, + update_header: bool = True, + name='review', + final_update_message=True, + max_previous_comments=4, + progress_response=None, + only_fold=False, + ): def _extract_link(comment_text: str): r = re.compile(r"") match = r.search(comment_text) @@ -263,7 +360,9 @@ class PRCodeSuggestions: up_to_commit_txt = f" up to commit {match.group(0)[4:-3].strip()}" return up_to_commit_txt - if isinstance(git_provider, AzureDevopsProvider): # get_latest_commit_url is not supported yet + if isinstance( + git_provider, AzureDevopsProvider + ): # get_latest_commit_url is not supported yet if progress_response: git_provider.edit_comment(progress_response, pr_comment) new_comment = progress_response @@ -273,7 +372,7 @@ class PRCodeSuggestions: history_header = f"#### Previous suggestions\n" last_commit_num = git_provider.get_latest_commit_url().split('/')[-1][:7] - if only_fold: # A user clicked on the 'self-review' checkbox + if only_fold: # A user clicked on the 'self-review' checkbox text = get_settings().pr_code_suggestions.code_suggestions_self_review_text latest_suggestion_header = f"\n\n- [x] {text}" else: @@ -300,42 +399,66 @@ class PRCodeSuggestions: # find http link from comment.body[:table_index] up_to_commit_txt = _extract_link(comment.body[:table_index]) prev_suggestion_table = comment.body[ - table_index:comment.body.rfind("") + len("")] + table_index : comment.body.rfind("") + + len("") + ] tick = "✅ " if "✅" in prev_suggestion_table else "" # surround with details tag prev_suggestion_table = f"
{tick}{name.capitalize()}{up_to_commit_txt}\n
{prev_suggestion_table}\n\n
" - new_suggestion_table = pr_comment.replace(initial_header, "").strip() + new_suggestion_table = pr_comment.replace( + initial_header, "" + ).strip() - pr_comment_updated = f"{initial_header}\n{latest_commit_html_comment}\n\n" + pr_comment_updated = ( + f"{initial_header}\n{latest_commit_html_comment}\n\n" + ) pr_comment_updated += f"{latest_suggestion_header}\n{new_suggestion_table}\n\n___\n\n" - pr_comment_updated += f"{history_header}{prev_suggestion_table}\n" + pr_comment_updated += ( + f"{history_header}{prev_suggestion_table}\n" + ) else: # get the text of the previous suggestions until the latest commit sections = prev_suggestions.split(history_header.strip()) latest_table = sections[0].strip() - prev_suggestion_table = sections[1].replace(history_header, "").strip() + prev_suggestion_table = ( + sections[1].replace(history_header, "").strip() + ) # get text after the latest_suggestion_header in comment.body table_ind = latest_table.find("") up_to_commit_txt = _extract_link(latest_table[:table_ind]) - latest_table = latest_table[table_ind:latest_table.rfind("
") + len("")] + latest_table = latest_table[ + table_ind : latest_table.rfind("") + + len("") + ] # enforce max_previous_comments - count = prev_suggestions.count(f"\n
<details><summary>{name.capitalize()}")
-                        count += prev_suggestions.count(f"\n<details><summary>✅ {name.capitalize()}")
+                        count = prev_suggestions.count(
+                            f"\n<details><summary>{name.capitalize()}"
+                        )
+                        count += prev_suggestions.count(
+                            f"\n<details><summary>✅ {name.capitalize()}"
+                        )
                         if count >= max_previous_comments:
                             # remove the oldest suggestion
-                            prev_suggestion_table = prev_suggestion_table[:prev_suggestion_table.rfind(
-                                f"<details><summary>{name.capitalize()} up to commit")]
+                            prev_suggestion_table = prev_suggestion_table[
+                                : prev_suggestion_table.rfind(
+                                    f"<details><summary>
{name.capitalize()} up to commit" + ) + ] tick = "✅ " if "✅" in latest_table else "" # Add to the prev_suggestions section last_prev_table = f"\n
{tick}{name.capitalize()}{up_to_commit_txt}\n
{latest_table}\n\n
" - prev_suggestion_table = last_prev_table + "\n" + prev_suggestion_table + prev_suggestion_table = ( + last_prev_table + "\n" + prev_suggestion_table + ) - new_suggestion_table = pr_comment.replace(initial_header, "").strip() + new_suggestion_table = pr_comment.replace( + initial_header, "" + ).strip() pr_comment_updated = f"{initial_header}\n" pr_comment_updated += f"{latest_commit_html_comment}\n\n" @@ -344,16 +467,24 @@ class PRCodeSuggestions: pr_comment_updated += f"{history_header}\n" pr_comment_updated += f"{prev_suggestion_table}\n" - get_logger().info(f"Persistent mode - updating comment {comment_url} to latest {name} message") - if progress_response: # publish to 'progress_response' comment, because it refreshes immediately - git_provider.edit_comment(progress_response, pr_comment_updated) + get_logger().info( + f"Persistent mode - updating comment {comment_url} to latest {name} message" + ) + if ( + progress_response + ): # publish to 'progress_response' comment, because it refreshes immediately + git_provider.edit_comment( + progress_response, pr_comment_updated + ) git_provider.remove_comment(comment) comment = progress_response else: git_provider.edit_comment(comment, pr_comment_updated) return comment except Exception as e: - get_logger().exception(f"Failed to update persistent review, error: {e}") + get_logger().exception( + f"Failed to update persistent review, error: {e}" + ) pass # if we are here, we did not find a previous comment to update @@ -366,7 +497,6 @@ class PRCodeSuggestions: new_comment = git_provider.publish_comment(pr_comment) return new_comment - def extract_link(self, s): r = re.compile(r"") match = r.search(s) @@ -377,17 +507,23 @@ class PRCodeSuggestions: return up_to_commit_txt async def _prepare_prediction(self, model: str) -> dict: - self.patches_diff = get_pr_diff(self.git_provider, - self.token_handler, - model, - add_line_numbers_to_hunks=True, - disable_extra_lines=False) + self.patches_diff = get_pr_diff( + self.git_provider, + self.token_handler, + model, + add_line_numbers_to_hunks=True, + disable_extra_lines=False, + ) self.patches_diff_list = [self.patches_diff] - self.patches_diff_no_line_number = self.remove_line_numbers([self.patches_diff])[0] + self.patches_diff_no_line_number = self.remove_line_numbers( + [self.patches_diff] + )[0] if self.patches_diff: get_logger().debug(f"PR diff", artifact=self.patches_diff) - self.prediction = await self._get_prediction(model, self.patches_diff, self.patches_diff_no_line_number) + self.prediction = await self._get_prediction( + model, self.patches_diff, self.patches_diff_no_line_number + ) else: get_logger().warning(f"Empty PR diff") self.prediction = None @@ -395,15 +531,25 @@ class PRCodeSuggestions: data = self.prediction return data - async def _get_prediction(self, model: str, patches_diff: str, patches_diff_no_line_number: str) -> dict: + async def _get_prediction( + self, model: str, patches_diff: str, patches_diff_no_line_number: str + ) -> dict: variables = copy.deepcopy(self.vars) variables["diff"] = patches_diff # update diff variables["diff_no_line_numbers"] = patches_diff_no_line_number # update diff environment = Environment(undefined=StrictUndefined) - system_prompt = environment.from_string(self.pr_code_suggestions_prompt_system).render(variables) - user_prompt = environment.from_string(get_settings().pr_code_suggestions_prompt.user).render(variables) + system_prompt = environment.from_string( + self.pr_code_suggestions_prompt_system + ).render(variables) + user_prompt = 
environment.from_string( + get_settings().pr_code_suggestions_prompt.user + ).render(variables) response, finish_reason = await self.ai_handler.chat_completion( - model=model, temperature=get_settings().config.temperature, system=system_prompt, user=user_prompt) + model=model, + temperature=get_settings().config.temperature, + system=system_prompt, + user=user_prompt, + ) if not get_settings().config.publish_output: get_settings().system_prompt = system_prompt get_settings().user_prompt = user_prompt @@ -413,8 +559,9 @@ class PRCodeSuggestions: # self-reflect on suggestions (mandatory, since line numbers are generated now here) model_reflection = get_settings().config.model - response_reflect = await self.self_reflect_on_suggestions(data["code_suggestions"], - patches_diff, model=model_reflection) + response_reflect = await self.self_reflect_on_suggestions( + data["code_suggestions"], patches_diff, model=model_reflection + ) if response_reflect: await self.analyze_self_reflection_response(data, response_reflect) else: @@ -428,15 +575,23 @@ class PRCodeSuggestions: async def analyze_self_reflection_response(self, data, response_reflect): response_reflect_yaml = load_yaml(response_reflect) code_suggestions_feedback = response_reflect_yaml.get("code_suggestions", []) - if code_suggestions_feedback and len(code_suggestions_feedback) == len(data["code_suggestions"]): + if code_suggestions_feedback and len(code_suggestions_feedback) == len( + data["code_suggestions"] + ): for i, suggestion in enumerate(data["code_suggestions"]): try: - suggestion["score"] = code_suggestions_feedback[i]["suggestion_score"] + suggestion["score"] = code_suggestions_feedback[i][ + "suggestion_score" + ] suggestion["score_why"] = code_suggestions_feedback[i]["why"] if 'relevant_lines_start' not in suggestion: - relevant_lines_start = code_suggestions_feedback[i].get('relevant_lines_start', -1) - relevant_lines_end = code_suggestions_feedback[i].get('relevant_lines_end', -1) + relevant_lines_start = code_suggestions_feedback[i].get( + 'relevant_lines_start', -1 + ) + relevant_lines_end = code_suggestions_feedback[i].get( + 'relevant_lines_end', -1 + ) suggestion['relevant_lines_start'] = relevant_lines_start suggestion['relevant_lines_end'] = relevant_lines_end if relevant_lines_start < 0 or relevant_lines_end < 0: @@ -450,18 +605,29 @@ class PRCodeSuggestions: score = int(suggestion["score"]) label = suggestion["label"].lower().strip() label = label.replace('
', ' ') - suggestion_statistics_dict = {'score': score, - 'label': label} - get_logger().info(f"PR-Agent suggestions statistics", - statistics=suggestion_statistics_dict, analytics=True) + suggestion_statistics_dict = { + 'score': score, + 'label': label, + } + get_logger().info( + f"PR-Agent suggestions statistics", + statistics=suggestion_statistics_dict, + analytics=True, + ) except Exception as e: - get_logger().error(f"Failed to log suggestion statistics, error: {e}") + get_logger().error( + f"Failed to log suggestion statistics, error: {e}" + ) pass except Exception as e: # - get_logger().error(f"Error processing suggestion score {i}", - artifact={"suggestion": suggestion, - "code_suggestions_feedback": code_suggestions_feedback[i]}) + get_logger().error( + f"Error processing suggestion score {i}", + artifact={ + "suggestion": suggestion, + "code_suggestions_feedback": code_suggestions_feedback[i], + }, + ) suggestion["score"] = 7 suggestion["score_why"] = "" @@ -469,30 +635,53 @@ class PRCodeSuggestions: try: if suggestion['existing_code'] == suggestion['improved_code']: get_logger().debug( - f"edited improved suggestion {i + 1}, because equal to existing code: {suggestion['existing_code']}") - if get_settings().pr_code_suggestions.commitable_code_suggestions: - suggestion['improved_code'] = "" # we need 'existing_code' to locate the code in the PR + f"edited improved suggestion {i + 1}, because equal to existing code: {suggestion['existing_code']}" + ) + if ( + get_settings().pr_code_suggestions.commitable_code_suggestions + ): + suggestion[ + 'improved_code' + ] = "" # we need 'existing_code' to locate the code in the PR else: suggestion['existing_code'] = "" except Exception as e: - get_logger().error(f"Error processing suggestion {i + 1}, error: {e}") + get_logger().error( + f"Error processing suggestion {i + 1}, error: {e}" + ) @staticmethod def _truncate_if_needed(suggestion): - max_code_suggestion_length = get_settings().get("PR_CODE_SUGGESTIONS.MAX_CODE_SUGGESTION_LENGTH", 0) - suggestion_truncation_message = get_settings().get("PR_CODE_SUGGESTIONS.SUGGESTION_TRUNCATION_MESSAGE", "") + max_code_suggestion_length = get_settings().get( + "PR_CODE_SUGGESTIONS.MAX_CODE_SUGGESTION_LENGTH", 0 + ) + suggestion_truncation_message = get_settings().get( + "PR_CODE_SUGGESTIONS.SUGGESTION_TRUNCATION_MESSAGE", "" + ) if max_code_suggestion_length > 0: if len(suggestion['improved_code']) > max_code_suggestion_length: - get_logger().info(f"Truncated suggestion from {len(suggestion['improved_code'])} " - f"characters to {max_code_suggestion_length} characters") - suggestion['improved_code'] = suggestion['improved_code'][:max_code_suggestion_length] + get_logger().info( + f"Truncated suggestion from {len(suggestion['improved_code'])} " + f"characters to {max_code_suggestion_length} characters" + ) + suggestion['improved_code'] = suggestion['improved_code'][ + :max_code_suggestion_length + ] suggestion['improved_code'] += f"\n{suggestion_truncation_message}" return suggestion def _prepare_pr_code_suggestions(self, predictions: str) -> Dict: - data = load_yaml(predictions.strip(), - keys_fix_yaml=["relevant_file", "suggestion_content", "existing_code", "improved_code"], - first_key="code_suggestions", last_key="label") + data = load_yaml( + predictions.strip(), + keys_fix_yaml=[ + "relevant_file", + "suggestion_content", + "existing_code", + "improved_code", + ], + first_key="code_suggestions", + last_key="label", + ) if isinstance(data, list): data = {'code_suggestions': data} @@ -507,24 
+696,35 @@ class PRCodeSuggestions: if key not in suggestion: is_valid_keys = False get_logger().debug( - f"Skipping suggestion {i + 1}, because it does not contain '{key}':\n'{suggestion}") + f"Skipping suggestion {i + 1}, because it does not contain '{key}':\n'{suggestion}" + ) break if not is_valid_keys: continue - if get_settings().get("pr_code_suggestions.focus_only_on_problems", False): + if get_settings().get( + "pr_code_suggestions.focus_only_on_problems", False + ): CRITICAL_LABEL = 'critical' - if CRITICAL_LABEL in suggestion['label'].lower(): # we want the published labels to be less declarative + if ( + CRITICAL_LABEL in suggestion['label'].lower() + ): # we want the published labels to be less declarative suggestion['label'] = 'possible issue' if suggestion['one_sentence_summary'] in one_sentence_summary_list: - get_logger().debug(f"Skipping suggestion {i + 1}, because it is a duplicate: {suggestion}") + get_logger().debug( + f"Skipping suggestion {i + 1}, because it is a duplicate: {suggestion}" + ) continue - if 'const' in suggestion['suggestion_content'] and 'instead' in suggestion[ - 'suggestion_content'] and 'let' in suggestion['suggestion_content']: + if ( + 'const' in suggestion['suggestion_content'] + and 'instead' in suggestion['suggestion_content'] + and 'let' in suggestion['suggestion_content'] + ): get_logger().debug( - f"Skipping suggestion {i + 1}, because it uses 'const instead let': {suggestion}") + f"Skipping suggestion {i + 1}, because it uses 'const instead let': {suggestion}" + ) continue if ('existing_code' in suggestion) and ('improved_code' in suggestion): @@ -533,9 +733,12 @@ class PRCodeSuggestions: suggestion_list.append(suggestion) else: get_logger().info( - f"Skipping suggestion {i + 1}, because it does not contain 'existing_code' or 'improved_code': {suggestion}") + f"Skipping suggestion {i + 1}, because it does not contain 'existing_code' or 'improved_code': {suggestion}" + ) except Exception as e: - get_logger().error(f"Error processing suggestion {i + 1}: {suggestion}, error: {e}") + get_logger().error( + f"Error processing suggestion {i + 1}: {suggestion}, error: {e}" + ) data['code_suggestions'] = suggestion_list return data @@ -546,46 +749,72 @@ class PRCodeSuggestions: if not data['code_suggestions']: get_logger().info('No suggestions found to improve this PR.') if self.progress_response: - return self.git_provider.edit_comment(self.progress_response, - body='No suggestions found to improve this PR.') + return self.git_provider.edit_comment( + self.progress_response, + body='No suggestions found to improve this PR.', + ) else: - return self.git_provider.publish_comment('No suggestions found to improve this PR.') + return self.git_provider.publish_comment( + 'No suggestions found to improve this PR.' 
+ ) for d in data['code_suggestions']: try: if get_settings().config.verbosity_level >= 2: get_logger().info(f"suggestion: {d}") relevant_file = d['relevant_file'].strip() - relevant_lines_start = int(d['relevant_lines_start']) # absolute position + relevant_lines_start = int( + d['relevant_lines_start'] + ) # absolute position relevant_lines_end = int(d['relevant_lines_end']) content = d['suggestion_content'].rstrip() new_code_snippet = d['improved_code'].rstrip() label = d['label'].strip() if new_code_snippet: - new_code_snippet = self.dedent_code(relevant_file, relevant_lines_start, new_code_snippet) + new_code_snippet = self.dedent_code( + relevant_file, relevant_lines_start, new_code_snippet + ) if d.get('score'): - body = f"**Suggestion:** {content} [{label}, importance: {d.get('score')}]\n```suggestion\n" + new_code_snippet + "\n```" + body = ( + f"**Suggestion:** {content} [{label}, importance: {d.get('score')}]\n```suggestion\n" + + new_code_snippet + + "\n```" + ) else: - body = f"**Suggestion:** {content} [{label}]\n```suggestion\n" + new_code_snippet + "\n```" - code_suggestions.append({'body': body, 'relevant_file': relevant_file, - 'relevant_lines_start': relevant_lines_start, - 'relevant_lines_end': relevant_lines_end, - 'original_suggestion': d}) + body = ( + f"**Suggestion:** {content} [{label}]\n```suggestion\n" + + new_code_snippet + + "\n```" + ) + code_suggestions.append( + { + 'body': body, + 'relevant_file': relevant_file, + 'relevant_lines_start': relevant_lines_start, + 'relevant_lines_end': relevant_lines_end, + 'original_suggestion': d, + } + ) except Exception: get_logger().info(f"Could not parse suggestion: {d}") is_successful = self.git_provider.publish_code_suggestions(code_suggestions) if not is_successful: - get_logger().info("Failed to publish code suggestions, trying to publish each suggestion separately") + get_logger().info( + "Failed to publish code suggestions, trying to publish each suggestion separately" + ) for code_suggestion in code_suggestions: self.git_provider.publish_code_suggestions([code_suggestion]) def dedent_code(self, relevant_file, relevant_lines_start, new_code_snippet): try: # dedent code snippet - self.diff_files = self.git_provider.diff_files if self.git_provider.diff_files \ + self.diff_files = ( + self.git_provider.diff_files + if self.git_provider.diff_files else self.git_provider.get_diff_files() + ) original_initial_line = None for file in self.diff_files: if file.filename.strip() == relevant_file: @@ -594,29 +823,44 @@ class PRCodeSuggestions: if relevant_lines_start > len(file_lines): get_logger().warning( "Could not dedent code snippet, because relevant_lines_start is out of range", - artifact={'filename': file.filename, - 'file_content': file.head_file, - 'relevant_lines_start': relevant_lines_start, - 'new_code_snippet': new_code_snippet}) + artifact={ + 'filename': file.filename, + 'file_content': file.head_file, + 'relevant_lines_start': relevant_lines_start, + 'new_code_snippet': new_code_snippet, + }, + ) return new_code_snippet else: original_initial_line = file_lines[relevant_lines_start - 1] else: - get_logger().warning("Could not dedent code snippet, because head_file is missing", - artifact={'filename': file.filename, - 'relevant_lines_start': relevant_lines_start, - 'new_code_snippet': new_code_snippet}) + get_logger().warning( + "Could not dedent code snippet, because head_file is missing", + artifact={ + 'filename': file.filename, + 'relevant_lines_start': relevant_lines_start, + 'new_code_snippet': 
new_code_snippet, + }, + ) return new_code_snippet break if original_initial_line: suggested_initial_line = new_code_snippet.splitlines()[0] - original_initial_spaces = len(original_initial_line) - len(original_initial_line.lstrip()) - suggested_initial_spaces = len(suggested_initial_line) - len(suggested_initial_line.lstrip()) + original_initial_spaces = len(original_initial_line) - len( + original_initial_line.lstrip() + ) + suggested_initial_spaces = len(suggested_initial_line) - len( + suggested_initial_line.lstrip() + ) delta_spaces = original_initial_spaces - suggested_initial_spaces if delta_spaces > 0: - new_code_snippet = textwrap.indent(new_code_snippet, delta_spaces * " ").rstrip('\n') + new_code_snippet = textwrap.indent( + new_code_snippet, delta_spaces * " " + ).rstrip('\n') except Exception as e: - get_logger().error(f"Error when dedenting code snippet for file {relevant_file}, error: {e}") + get_logger().error( + f"Error when dedenting code snippet for file {relevant_file}, error: {e}" + ) return new_code_snippet @@ -644,42 +888,72 @@ class PRCodeSuggestions: # find the first letter in the line that starts with a valid letter for j, char in enumerate(line): if not char.isdigit(): - patches_diff_lines[i] = line[j + 1:] + patches_diff_lines[i] = line[j + 1 :] break - self.patches_diff_list_no_line_numbers.append('\n'.join(patches_diff_lines)) + self.patches_diff_list_no_line_numbers.append( + '\n'.join(patches_diff_lines) + ) return self.patches_diff_list_no_line_numbers except Exception as e: - get_logger().error(f"Error removing line numbers from patches_diff_list, error: {e}") + get_logger().error( + f"Error removing line numbers from patches_diff_list, error: {e}" + ) return patches_diff_list async def _prepare_prediction_extended(self, model: str) -> dict: - self.patches_diff_list = get_pr_multi_diffs(self.git_provider, self.token_handler, model, - max_calls=get_settings().pr_code_suggestions.max_number_of_calls) + self.patches_diff_list = get_pr_multi_diffs( + self.git_provider, + self.token_handler, + model, + max_calls=get_settings().pr_code_suggestions.max_number_of_calls, + ) # create a copy of the patches_diff_list, without line numbers for '__new hunk__' sections - self.patches_diff_list_no_line_numbers = self.remove_line_numbers(self.patches_diff_list) + self.patches_diff_list_no_line_numbers = self.remove_line_numbers( + self.patches_diff_list + ) if self.patches_diff_list: - get_logger().info(f"Number of PR chunk calls: {len(self.patches_diff_list)}") + get_logger().info( + f"Number of PR chunk calls: {len(self.patches_diff_list)}" + ) get_logger().debug(f"PR diff:", artifact=self.patches_diff_list) # parallelize calls to AI: if get_settings().pr_code_suggestions.parallel_calls: prediction_list = await asyncio.gather( - *[self._get_prediction(model, patches_diff, patches_diff_no_line_numbers) for - patches_diff, patches_diff_no_line_numbers in - zip(self.patches_diff_list, self.patches_diff_list_no_line_numbers)]) + *[ + self._get_prediction( + model, patches_diff, patches_diff_no_line_numbers + ) + for patches_diff, patches_diff_no_line_numbers in zip( + self.patches_diff_list, + self.patches_diff_list_no_line_numbers, + ) + ] + ) self.prediction_list = prediction_list else: prediction_list = [] - for patches_diff, patches_diff_no_line_numbers in zip(self.patches_diff_list, self.patches_diff_list_no_line_numbers): - prediction = await self._get_prediction(model, patches_diff, patches_diff_no_line_numbers) + for patches_diff, patches_diff_no_line_numbers in 
zip( + self.patches_diff_list, self.patches_diff_list_no_line_numbers + ): + prediction = await self._get_prediction( + model, patches_diff, patches_diff_no_line_numbers + ) prediction_list.append(prediction) data = {"code_suggestions": []} - for j, predictions in enumerate(prediction_list): # each call adds an element to the list + for j, predictions in enumerate( + prediction_list + ): # each call adds an element to the list if "code_suggestions" in predictions: - score_threshold = max(1, int(get_settings().pr_code_suggestions.suggestions_score_threshold)) + score_threshold = max( + 1, + int( + get_settings().pr_code_suggestions.suggestions_score_threshold + ), + ) for i, prediction in enumerate(predictions["code_suggestions"]): try: score = int(prediction.get("score", 1)) @@ -688,10 +962,13 @@ class PRCodeSuggestions: else: get_logger().info( f"Removing suggestions {i} from call {j}, because score is {score}, and score_threshold is {score_threshold}", - artifact=prediction) + artifact=prediction, + ) except Exception as e: - get_logger().error(f"Error getting PR diff for suggestion {i} in call {j}, error: {e}", - artifact={"prediction": prediction}) + get_logger().error( + f"Error getting PR diff for suggestion {i} in call {j}, error: {e}", + artifact={"prediction": prediction}, + ) self.data = data else: get_logger().warning(f"Empty PR diff list") @@ -706,7 +983,10 @@ class PRCodeSuggestions: pr_body += "No suggestions found to improve this PR." return pr_body - if get_settings().pr_code_suggestions.enable_intro_text and get_settings().config.is_auto_command: + if ( + get_settings().pr_code_suggestions.enable_intro_text + and get_settings().config.is_auto_command + ): pr_body += "Explore these optional code suggestions:\n\n" language_extension_map_org = get_settings().language_extension_map_org @@ -731,17 +1011,25 @@ class PRCodeSuggestions: # sort suggestions_labels by the suggestion with the highest score suggestions_labels = dict( - sorted(suggestions_labels.items(), key=lambda x: max([s['score'] for s in x[1]]), reverse=True)) + sorted( + suggestions_labels.items(), + key=lambda x: max([s['score'] for s in x[1]]), + reverse=True, + ) + ) # sort the suggestions inside each label group by score for label, suggestions in suggestions_labels.items(): - suggestions_labels[label] = sorted(suggestions, key=lambda x: x['score'], reverse=True) + suggestions_labels[label] = sorted( + suggestions, key=lambda x: x['score'], reverse=True + ) counter_suggestions = 0 for label, suggestions in suggestions_labels.items(): num_suggestions = len(suggestions) - pr_body += f"""{label.capitalize()}\n""" + pr_body += ( + f"""{label.capitalize()}\n""" + ) for i, suggestion in enumerate(suggestions): - relevant_file = suggestion['relevant_file'].strip() relevant_lines_start = int(suggestion['relevant_lines_start']) relevant_lines_end = int(suggestion['relevant_lines_end']) @@ -752,21 +1040,25 @@ class PRCodeSuggestions: range_str = f"[{relevant_lines_start}-{relevant_lines_end}]" try: - code_snippet_link = self.git_provider.get_line_link(relevant_file, relevant_lines_start, - relevant_lines_end) + code_snippet_link = self.git_provider.get_line_link( + relevant_file, relevant_lines_start, relevant_lines_end + ) except: code_snippet_link = "" # add html table for each suggestion suggestion_content = suggestion['suggestion_content'].rstrip() CHAR_LIMIT_PER_LINE = 84 - suggestion_content = insert_br_after_x_chars(suggestion_content, CHAR_LIMIT_PER_LINE) + suggestion_content = insert_br_after_x_chars( + 
suggestion_content, CHAR_LIMIT_PER_LINE + ) # pr_body += f"
{suggestion_content}" existing_code = suggestion['existing_code'].rstrip() + "\n" improved_code = suggestion['improved_code'].rstrip() + "\n" - diff = difflib.unified_diff(existing_code.split('\n'), - improved_code.split('\n'), n=999) + diff = difflib.unified_diff( + existing_code.split('\n'), improved_code.split('\n'), n=999 + ) patch_orig = "\n".join(diff) patch = "\n".join(patch_orig.splitlines()[5:]).strip('\n') @@ -776,10 +1068,14 @@ class PRCodeSuggestions: pr_body += f"""\n\n""" else: pr_body += f"""\n\n""" - suggestion_summary = suggestion['one_sentence_summary'].strip().rstrip('.') + suggestion_summary = ( + suggestion['one_sentence_summary'].strip().rstrip('.') + ) if "'<" in suggestion_summary and ">'" in suggestion_summary: # escape the '<' and '>' characters, otherwise they are interpreted as html tags - get_logger().info(f"Escaped suggestion summary: {suggestion_summary}") + get_logger().info( + f"Escaped suggestion summary: {suggestion_summary}" + ) suggestion_summary = suggestion_summary.replace("'<", "`<") suggestion_summary = suggestion_summary.replace(">'", ">`") if '`' in suggestion_summary: @@ -815,12 +1111,18 @@ class PRCodeSuggestions: pr_body += """""" return pr_body except Exception as e: - get_logger().info(f"Failed to publish summarized code suggestions, error: {e}") + get_logger().info( + f"Failed to publish summarized code suggestions, error: {e}" + ) return "" def get_score_str(self, score: int) -> str: - th_high = get_settings().pr_code_suggestions.get('new_score_mechanism_th_high', 9) - th_medium = get_settings().pr_code_suggestions.get('new_score_mechanism_th_medium', 7) + th_high = get_settings().pr_code_suggestions.get( + 'new_score_mechanism_th_high', 9 + ) + th_medium = get_settings().pr_code_suggestions.get( + 'new_score_mechanism_th_medium', 7 + ) if score >= th_high: return "高" elif score >= th_medium: @@ -828,12 +1130,14 @@ class PRCodeSuggestions: else: # score < 7 return "低" - async def self_reflect_on_suggestions(self, - suggestion_list: List, - patches_diff: str, - model: str, - prev_suggestions_str: str = "", - dedicated_prompt: str = "") -> str: + async def self_reflect_on_suggestions( + self, + suggestion_list: List, + patches_diff: str, + model: str, + prev_suggestions_str: str = "", + dedicated_prompt: str = "", + ) -> str: if not suggestion_list: return "" @@ -842,31 +1146,44 @@ class PRCodeSuggestions: for i, suggestion in enumerate(suggestion_list): suggestion_str += f"suggestion {i + 1}: " + str(suggestion) + '\n\n' - variables = {'suggestion_list': suggestion_list, - 'suggestion_str': suggestion_str, - "diff": patches_diff, - 'num_code_suggestions': len(suggestion_list), - 'prev_suggestions_str': prev_suggestions_str, - "is_ai_metadata": get_settings().get("config.enable_ai_metadata", False), - 'duplicate_prompt_examples': get_settings().config.get('duplicate_prompt_examples', False)} + variables = { + 'suggestion_list': suggestion_list, + 'suggestion_str': suggestion_str, + "diff": patches_diff, + 'num_code_suggestions': len(suggestion_list), + 'prev_suggestions_str': prev_suggestions_str, + "is_ai_metadata": get_settings().get( + "config.enable_ai_metadata", False + ), + 'duplicate_prompt_examples': get_settings().config.get( + 'duplicate_prompt_examples', False + ), + } environment = Environment(undefined=StrictUndefined) if dedicated_prompt: system_prompt_reflect = environment.from_string( - get_settings().get(dedicated_prompt).system).render(variables) + get_settings().get(dedicated_prompt).system + ).render(variables) 
user_prompt_reflect = environment.from_string( - get_settings().get(dedicated_prompt).user).render(variables) + get_settings().get(dedicated_prompt).user + ).render(variables) else: system_prompt_reflect = environment.from_string( - get_settings().pr_code_suggestions_reflect_prompt.system).render(variables) + get_settings().pr_code_suggestions_reflect_prompt.system + ).render(variables) user_prompt_reflect = environment.from_string( - get_settings().pr_code_suggestions_reflect_prompt.user).render(variables) + get_settings().pr_code_suggestions_reflect_prompt.user + ).render(variables) with get_logger().contextualize(command="self_reflect_on_suggestions"): - response_reflect, finish_reason_reflect = await self.ai_handler.chat_completion(model=model, - system=system_prompt_reflect, - user=user_prompt_reflect) + ( + response_reflect, + finish_reason_reflect, + ) = await self.ai_handler.chat_completion( + model=model, system=system_prompt_reflect, user=user_prompt_reflect + ) except Exception as e: get_logger().info(f"Could not reflect on suggestions, error: {e}") return "" - return response_reflect \ No newline at end of file + return response_reflect diff --git a/apps/utils/pr_agent/tools/pr_config.py b/apps/utils/pr_agent/tools/pr_config.py index a00e015..cfee29e 100644 --- a/apps/utils/pr_agent/tools/pr_config.py +++ b/apps/utils/pr_agent/tools/pr_config.py @@ -9,6 +9,7 @@ class PRConfig: """ The PRConfig class is responsible for listing all configuration options available for the user. """ + def __init__(self, pr_url: str, args=None, ai_handler=None): """ Initialize the PRConfig object with the necessary attributes and objects to comment on a pull request. @@ -34,20 +35,43 @@ class PRConfig: conf_settings = Dynaconf(settings_files=[conf_file]) configuration_headers = [header.lower() for header in conf_settings.keys()] relevant_configs = { - header: configs for header, configs in get_settings().to_dict().items() - if (header.lower().startswith("pr_") or header.lower().startswith("config")) and header.lower() in configuration_headers + header: configs + for header, configs in get_settings().to_dict().items() + if (header.lower().startswith("pr_") or header.lower().startswith("config")) + and header.lower() in configuration_headers } - skip_keys = ['ai_disclaimer', 'ai_disclaimer_title', 'ANALYTICS_FOLDER', 'secret_provider', "skip_keys", "app_id", "redirect", - 'trial_prefix_message', 'no_eligible_message', 'identity_provider', 'ALLOWED_REPOS', - 'APP_NAME', 'PERSONAL_ACCESS_TOKEN', 'shared_secret', 'key', 'AWS_ACCESS_KEY_ID', 'AWS_SECRET_ACCESS_KEY', 'user_token', - 'private_key', 'private_key_id', 'client_id', 'client_secret', 'token', 'bearer_token'] + skip_keys = [ + 'ai_disclaimer', + 'ai_disclaimer_title', + 'ANALYTICS_FOLDER', + 'secret_provider', + "skip_keys", + "app_id", + "redirect", + 'trial_prefix_message', + 'no_eligible_message', + 'identity_provider', + 'ALLOWED_REPOS', + 'APP_NAME', + 'PERSONAL_ACCESS_TOKEN', + 'shared_secret', + 'key', + 'AWS_ACCESS_KEY_ID', + 'AWS_SECRET_ACCESS_KEY', + 'user_token', + 'private_key', + 'private_key_id', + 'client_id', + 'client_secret', + 'token', + 'bearer_token', + ] extra_skip_keys = get_settings().config.get('config.skip_keys', []) if extra_skip_keys: skip_keys.extend(extra_skip_keys) skip_keys_lower = [key.lower() for key in skip_keys] - markdown_text = "
🛠️ PR-Agent Configurations: \n\n" markdown_text += f"\n\n```yaml\n\n" for header, configs in relevant_configs.items(): @@ -61,5 +85,7 @@ class PRConfig: markdown_text += " " markdown_text += "\n```" markdown_text += "\n
\n" - get_logger().info(f"Possible Configurations outputted to PR comment", artifact=markdown_text) + get_logger().info( + f"Possible Configurations outputted to PR comment", artifact=markdown_text + ) return markdown_text diff --git a/apps/utils/pr_agent/tools/pr_description.py b/apps/utils/pr_agent/tools/pr_description.py index 89a589b..f929e81 100644 --- a/apps/utils/pr_agent/tools/pr_description.py +++ b/apps/utils/pr_agent/tools/pr_description.py @@ -10,27 +10,38 @@ from jinja2 import Environment, StrictUndefined from utils.pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler from utils.pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler -from utils.pr_agent.algo.pr_processing import (OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD, - get_pr_diff, - get_pr_diff_multiple_patchs, - retry_with_fallback_models) +from utils.pr_agent.algo.pr_processing import ( + OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD, + get_pr_diff, + get_pr_diff_multiple_patchs, + retry_with_fallback_models, +) from utils.pr_agent.algo.token_handler import TokenHandler -from utils.pr_agent.algo.utils import (ModelType, PRDescriptionHeader, clip_tokens, - get_max_tokens, get_user_labels, load_yaml, - set_custom_labels, - show_relevant_configurations) +from utils.pr_agent.algo.utils import ( + ModelType, + PRDescriptionHeader, + clip_tokens, + get_max_tokens, + get_user_labels, + load_yaml, + set_custom_labels, + show_relevant_configurations, +) from utils.pr_agent.config_loader import get_settings -from utils.pr_agent.git_providers import (GithubProvider, get_git_provider_with_context) +from utils.pr_agent.git_providers import GithubProvider, get_git_provider_with_context from utils.pr_agent.git_providers.git_provider import get_main_pr_language from utils.pr_agent.log import get_logger from utils.pr_agent.servers.help import HelpMessage -from utils.pr_agent.tools.ticket_pr_compliance_check import ( - extract_and_cache_pr_tickets) +from utils.pr_agent.tools.ticket_pr_compliance_check import extract_and_cache_pr_tickets class PRDescription: - def __init__(self, pr_url: str, args: list = None, - ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler): + def __init__( + self, + pr_url: str, + args: list = None, + ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler, + ): """ Initialize the PRDescription object with the necessary attributes and objects for generating a PR description using an AI model. @@ -44,11 +55,22 @@ class PRDescription: self.git_provider.get_languages(), self.git_provider.get_files() ) self.pr_id = self.git_provider.get_pr_id() - self.keys_fix = ["filename:", "language:", "changes_summary:", "changes_title:", "description:", "title:"] + self.keys_fix = [ + "filename:", + "language:", + "changes_summary:", + "changes_title:", + "description:", + "title:", + ] - if get_settings().pr_description.enable_semantic_files_types and not self.git_provider.is_supported( - "gfm_markdown"): - get_logger().debug(f"Disabling semantic files types for {self.pr_id}, gfm_markdown not supported.") + if ( + get_settings().pr_description.enable_semantic_files_types + and not self.git_provider.is_supported("gfm_markdown") + ): + get_logger().debug( + f"Disabling semantic files types for {self.pr_id}, gfm_markdown not supported." 
+ ) get_settings().pr_description.enable_semantic_files_types = False # Initialize the AI handler @@ -56,7 +78,9 @@ class PRDescription: self.ai_handler.main_pr_language = self.main_pr_language # Initialize the variables dictionary - self.COLLAPSIBLE_FILE_LIST_THRESHOLD = get_settings().pr_description.get("collapsible_file_list_threshold", 8) + self.COLLAPSIBLE_FILE_LIST_THRESHOLD = get_settings().pr_description.get( + "collapsible_file_list_threshold", 8 + ) self.vars = { "title": self.git_provider.pr.title, "branch": self.git_provider.get_pr_branch(), @@ -69,8 +93,11 @@ class PRDescription: "custom_labels_class": "", # will be filled if necessary in 'set_custom_labels' function "enable_semantic_files_types": get_settings().pr_description.enable_semantic_files_types, "related_tickets": "", - "include_file_summary_changes": len(self.git_provider.get_diff_files()) <= self.COLLAPSIBLE_FILE_LIST_THRESHOLD, - 'duplicate_prompt_examples': get_settings().config.get('duplicate_prompt_examples', False), + "include_file_summary_changes": len(self.git_provider.get_diff_files()) + <= self.COLLAPSIBLE_FILE_LIST_THRESHOLD, + 'duplicate_prompt_examples': get_settings().config.get( + 'duplicate_prompt_examples', False + ), } self.user_description = self.git_provider.get_user_description() @@ -91,10 +118,14 @@ class PRDescription: async def run(self): try: get_logger().info(f"Generating a PR description for pr_id: {self.pr_id}") - relevant_configs = {'pr_description': dict(get_settings().pr_description), - 'config': dict(get_settings().config)} + relevant_configs = { + 'pr_description': dict(get_settings().pr_description), + 'config': dict(get_settings().config), + } get_logger().debug("Relevant configs", artifacts=relevant_configs) - if get_settings().config.publish_output and not get_settings().config.get('is_auto_command', False): + if get_settings().config.publish_output and not get_settings().config.get( + 'is_auto_command', False + ): self.git_provider.publish_comment("准备 PR 描述中...", is_temporary=True) # ticket extraction if exists @@ -119,40 +150,73 @@ class PRDescription: get_logger().debug(f"Publishing labels disabled") if get_settings().pr_description.use_description_markers: - pr_title, pr_body, changes_walkthrough, pr_file_changes = self._prepare_pr_answer_with_markers() + ( + pr_title, + pr_body, + changes_walkthrough, + pr_file_changes, + ) = self._prepare_pr_answer_with_markers() else: - pr_title, pr_body, changes_walkthrough, pr_file_changes = self._prepare_pr_answer() - if not self.git_provider.is_supported( - "publish_file_comments") or not get_settings().pr_description.inline_file_summary: + ( + pr_title, + pr_body, + changes_walkthrough, + pr_file_changes, + ) = self._prepare_pr_answer() + if ( + not self.git_provider.is_supported("publish_file_comments") + or not get_settings().pr_description.inline_file_summary + ): pr_body += "\n\n" + changes_walkthrough - get_logger().debug("PR output", artifact={"title": pr_title, "body": pr_body}) + get_logger().debug( + "PR output", artifact={"title": pr_title, "body": pr_body} + ) # Add help text if gfm_markdown is supported - if self.git_provider.is_supported("gfm_markdown") and get_settings().pr_description.enable_help_text: + if ( + self.git_provider.is_supported("gfm_markdown") + and get_settings().pr_description.enable_help_text + ): pr_body += "
 \n\n ✨ 工具使用指南: \n\n"
                 pr_body += HelpMessage.get_describe_usage_guide()
                 pr_body += "\n \n"
-            elif get_settings().pr_description.enable_help_comment and self.git_provider.is_supported("gfm_markdown"):
+            elif (
+                get_settings().pr_description.enable_help_comment
+                and self.git_provider.is_supported("gfm_markdown")
+            ):
                 if isinstance(self.git_provider, GithubProvider):
-                    pr_body += ('\n\n___\n\n>需要帮助?  • Type /help 如何 ... '
-                                '关于PR-Agent使用的任何问题,请在评论区留言.  • 查看一下 '
-                                'documentation '
-                                '了解更多.  • ')
-                else: # gitlab
-                    pr_body += ("\n\n___\n\n    需要帮助?- Type /help 如何 ... 在评论中 "
-                                "关于PR-Agent使用的任何问题请在此发帖.    - 查看一下 "
-                                "documentation 了解更多.    ")
+                    pr_body += (
+                        '\n\n___\n\n>    需要帮助?  • Type /help 如何 ... '
+                        '关于PR-Agent使用的任何问题,请在评论区留言.  • 查看一下 '
+                        'documentation '
+                        '了解更多.  • '
+                    )
+                else:  # gitlab
+                    pr_body += (
+                        "\n\n___\n\n    需要帮助?- Type /help 如何 ... 在评论中 "
+                        "关于PR-Agent使用的任何问题请在此发帖.    - 查看一下 "
+                        "documentation 了解更多.
    " + ) # elif get_settings().pr_description.enable_help_comment: # pr_body += '\n\n___\n\n> 💡 **PR-Agent usage**: Comment `/help "your question"` on any pull request to receive relevant information' # Output the relevant configurations if enabled - if get_settings().get('config', {}).get('output_relevant_configurations', False): - pr_body += show_relevant_configurations(relevant_section='pr_description') + if ( + get_settings() + .get('config', {}) + .get('output_relevant_configurations', False) + ): + pr_body += show_relevant_configurations( + relevant_section='pr_description' + ) if get_settings().config.publish_output: - # publish labels - if get_settings().pr_description.publish_labels and pr_labels and self.git_provider.is_supported("get_labels"): + if ( + get_settings().pr_description.publish_labels + and pr_labels + and self.git_provider.is_supported("get_labels") + ): original_labels = self.git_provider.get_pr_labels(update=True) get_logger().debug(f"original labels", artifact=original_labels) user_labels = get_user_labels(original_labels) @@ -165,20 +229,29 @@ class PRDescription: # publish description if get_settings().pr_description.publish_description_as_comment: - full_markdown_description = f"## Title\n\n{pr_title}\n\n___\n{pr_body}" - if get_settings().pr_description.publish_description_as_comment_persistent: - self.git_provider.publish_persistent_comment(full_markdown_description, - initial_header="## Title", - update_header=True, - name="describe", - final_update_message=False, ) + full_markdown_description = ( + f"## Title\n\n{pr_title}\n\n___\n{pr_body}" + ) + if ( + get_settings().pr_description.publish_description_as_comment_persistent + ): + self.git_provider.publish_persistent_comment( + full_markdown_description, + initial_header="## Title", + update_header=True, + name="describe", + final_update_message=False, + ) else: self.git_provider.publish_comment(full_markdown_description) else: self.git_provider.publish_description(pr_title, pr_body) # publish final update message - if (get_settings().pr_description.final_update_message and not get_settings().config.get('is_auto_command', False)): + if ( + get_settings().pr_description.final_update_message + and not get_settings().config.get('is_auto_command', False) + ): latest_commit_url = self.git_provider.get_latest_commit_url() if latest_commit_url: pr_url = self.git_provider.get_pr_url() @@ -186,22 +259,40 @@ class PRDescription: self.git_provider.publish_comment(update_comment) self.git_provider.remove_initial_comment() else: - get_logger().info('PR description, but not published since publish_output is False.') + get_logger().info( + 'PR description, but not published since publish_output is False.' + ) get_settings().data = {"artifact": pr_body} return except Exception as e: - get_logger().error(f"Error generating PR description {self.pr_id}: {e}", - artifact={"traceback": traceback.format_exc()}) + get_logger().error( + f"Error generating PR description {self.pr_id}: {e}", + artifact={"traceback": traceback.format_exc()}, + ) return "" async def _prepare_prediction(self, model: str) -> None: - if get_settings().pr_description.use_description_markers and 'pr_agent:' not in self.user_description: - get_logger().info("Markers were enabled, but user description does not contain markers. skipping AI prediction") + if ( + get_settings().pr_description.use_description_markers + and 'pr_agent:' not in self.user_description + ): + get_logger().info( + "Markers were enabled, but user description does not contain markers. 
skipping AI prediction" + ) return None - large_pr_handling = get_settings().pr_description.enable_large_pr_handling and "pr_description_only_files_prompts" in get_settings() - output = get_pr_diff(self.git_provider, self.token_handler, model, large_pr_handling=large_pr_handling, return_remaining_files=True) + large_pr_handling = ( + get_settings().pr_description.enable_large_pr_handling + and "pr_description_only_files_prompts" in get_settings() + ) + output = get_pr_diff( + self.git_provider, + self.token_handler, + model, + large_pr_handling=large_pr_handling, + return_remaining_files=True, + ) if isinstance(output, tuple): patches_diff, remaining_files_list = output else: @@ -213,14 +304,18 @@ class PRDescription: if patches_diff: # generate the prediction get_logger().debug(f"PR diff", artifact=self.patches_diff) - self.prediction = await self._get_prediction(model, patches_diff, prompt="pr_description_prompt") + self.prediction = await self._get_prediction( + model, patches_diff, prompt="pr_description_prompt" + ) # extend the prediction with additional files not shown if get_settings().pr_description.enable_semantic_files_types: self.prediction = await self.extend_uncovered_files(self.prediction) else: - get_logger().error(f"Error getting PR diff {self.pr_id}", - artifact={"traceback": traceback.format_exc()}) + get_logger().error( + f"Error getting PR diff {self.pr_id}", + artifact={"traceback": traceback.format_exc()}, + ) self.prediction = None else: # get the diff in multiple patches, with the token handler only for the files prompt @@ -231,9 +326,16 @@ class PRDescription: get_settings().pr_description_only_files_prompts.system, get_settings().pr_description_only_files_prompts.user, ) - (patches_compressed_list, total_tokens_list, deleted_files_list, remaining_files_list, file_dict, - files_in_patches_list) = get_pr_diff_multiple_patchs( - self.git_provider, token_handler_only_files_prompt, model) + ( + patches_compressed_list, + total_tokens_list, + deleted_files_list, + remaining_files_list, + file_dict, + files_in_patches_list, + ) = get_pr_diff_multiple_patchs( + self.git_provider, token_handler_only_files_prompt, model + ) # get the files prediction for each patch if not get_settings().pr_description.async_ai_calls: @@ -241,8 +343,9 @@ class PRDescription: for i, patches in enumerate(patches_compressed_list): # sync calls patches_diff = "\n".join(patches) get_logger().debug(f"PR diff number {i + 1} for describe files") - prediction_files = await self._get_prediction(model, patches_diff, - prompt="pr_description_only_files_prompts") + prediction_files = await self._get_prediction( + model, patches_diff, prompt="pr_description_only_files_prompts" + ) results.append(prediction_files) else: # async calls tasks = [] @@ -251,34 +354,52 @@ class PRDescription: patches_diff = "\n".join(patches) get_logger().debug(f"PR diff number {i + 1} for describe files") task = asyncio.create_task( - self._get_prediction(model, patches_diff, prompt="pr_description_only_files_prompts")) + self._get_prediction( + model, + patches_diff, + prompt="pr_description_only_files_prompts", + ) + ) tasks.append(task) # Wait for all tasks to complete results = await asyncio.gather(*tasks) file_description_str_list = [] for i, result in enumerate(results): - prediction_files = result.strip().removeprefix('```yaml').strip('`').strip() - if load_yaml(prediction_files, keys_fix_yaml=self.keys_fix) and prediction_files.startswith('pr_files'): - prediction_files = 
prediction_files.removeprefix('pr_files:').strip() + prediction_files = ( + result.strip().removeprefix('```yaml').strip('`').strip() + ) + if load_yaml( + prediction_files, keys_fix_yaml=self.keys_fix + ) and prediction_files.startswith('pr_files'): + prediction_files = prediction_files.removeprefix( + 'pr_files:' + ).strip() file_description_str_list.append(prediction_files) else: - get_logger().debug(f"failed to generate predictions in iteration {i + 1} for describe files") + get_logger().debug( + f"failed to generate predictions in iteration {i + 1} for describe files" + ) # generate files_walkthrough string, with proper token handling token_handler_only_description_prompt = TokenHandler( self.git_provider.pr, self.vars, get_settings().pr_description_only_description_prompts.system, - get_settings().pr_description_only_description_prompts.user) + get_settings().pr_description_only_description_prompts.user, + ) files_walkthrough = "\n".join(file_description_str_list) files_walkthrough_prompt = copy.deepcopy(files_walkthrough) MAX_EXTRA_FILES_TO_PROMPT = 50 if remaining_files_list: - files_walkthrough_prompt += "\n\nNo more token budget. Additional unprocessed files:" + files_walkthrough_prompt += ( + "\n\nNo more token budget. Additional unprocessed files:" + ) for i, file in enumerate(remaining_files_list): files_walkthrough_prompt += f"\n- {file}" if i >= MAX_EXTRA_FILES_TO_PROMPT: - get_logger().debug(f"Too many remaining files, clipping to {MAX_EXTRA_FILES_TO_PROMPT}") + get_logger().debug( + f"Too many remaining files, clipping to {MAX_EXTRA_FILES_TO_PROMPT}" + ) files_walkthrough_prompt += f"\n... and {len(remaining_files_list) - MAX_EXTRA_FILES_TO_PROMPT} more" break if deleted_files_list: @@ -286,32 +407,57 @@ class PRDescription: for i, file in enumerate(deleted_files_list): files_walkthrough_prompt += f"\n- {file}" if i >= MAX_EXTRA_FILES_TO_PROMPT: - get_logger().debug(f"Too many deleted files, clipping to {MAX_EXTRA_FILES_TO_PROMPT}") + get_logger().debug( + f"Too many deleted files, clipping to {MAX_EXTRA_FILES_TO_PROMPT}" + ) files_walkthrough_prompt += f"\n... 
and {len(deleted_files_list) - MAX_EXTRA_FILES_TO_PROMPT} more" break tokens_files_walkthrough = len( - token_handler_only_description_prompt.encoder.encode(files_walkthrough_prompt)) - total_tokens = token_handler_only_description_prompt.prompt_tokens + tokens_files_walkthrough + token_handler_only_description_prompt.encoder.encode( + files_walkthrough_prompt + ) + ) + total_tokens = ( + token_handler_only_description_prompt.prompt_tokens + + tokens_files_walkthrough + ) max_tokens_model = get_max_tokens(model) if total_tokens > max_tokens_model - OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD: # clip files_walkthrough to git the tokens within the limit - files_walkthrough_prompt = clip_tokens(files_walkthrough_prompt, - max_tokens_model - OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD - token_handler_only_description_prompt.prompt_tokens, - num_input_tokens=tokens_files_walkthrough) + files_walkthrough_prompt = clip_tokens( + files_walkthrough_prompt, + max_tokens_model + - OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD + - token_handler_only_description_prompt.prompt_tokens, + num_input_tokens=tokens_files_walkthrough, + ) # PR header inference - get_logger().debug(f"PR diff only description", artifact=files_walkthrough_prompt) - prediction_headers = await self._get_prediction(model, patches_diff=files_walkthrough_prompt, - prompt="pr_description_only_description_prompts") - prediction_headers = prediction_headers.strip().removeprefix('```yaml').strip('`').strip() + get_logger().debug( + f"PR diff only description", artifact=files_walkthrough_prompt + ) + prediction_headers = await self._get_prediction( + model, + patches_diff=files_walkthrough_prompt, + prompt="pr_description_only_description_prompts", + ) + prediction_headers = ( + prediction_headers.strip().removeprefix('```yaml').strip('`').strip() + ) # extend the tables with the files not shown - files_walkthrough_extended = await self.extend_uncovered_files(files_walkthrough) + files_walkthrough_extended = await self.extend_uncovered_files( + files_walkthrough + ) # final processing - self.prediction = prediction_headers + "\n" + "pr_files:\n" + files_walkthrough_extended + self.prediction = ( + prediction_headers + "\n" + "pr_files:\n" + files_walkthrough_extended + ) if not load_yaml(self.prediction, keys_fix_yaml=self.keys_fix): - get_logger().error(f"Error getting valid YAML in large PR handling for describe {self.pr_id}") + get_logger().error( + f"Error getting valid YAML in large PR handling for describe {self.pr_id}" + ) if load_yaml(prediction_headers, keys_fix_yaml=self.keys_fix): get_logger().debug(f"Using only headers for describe {self.pr_id}") self.prediction = prediction_headers @@ -321,12 +467,17 @@ class PRDescription: prediction = original_prediction # get the original prediction filenames - original_prediction_loaded = load_yaml(original_prediction, keys_fix_yaml=self.keys_fix) + original_prediction_loaded = load_yaml( + original_prediction, keys_fix_yaml=self.keys_fix + ) if isinstance(original_prediction_loaded, list): original_prediction_dict = {"pr_files": original_prediction_loaded} else: original_prediction_dict = original_prediction_loaded - filenames_predicted = [file['filename'].strip() for file in original_prediction_dict.get('pr_files', [])] + filenames_predicted = [ + file['filename'].strip() + for file in original_prediction_dict.get('pr_files', []) + ] # extend the prediction with additional files not included in the original prediction pr_files = self.git_provider.get_diff_files() @@ -349,7 +500,9 @@ class PRDescription: 
additional files """ prediction_extra = prediction_extra + "\n" + extra_file_yaml.strip() - get_logger().debug(f"Too many remaining files, clipping to {MAX_EXTRA_FILES_TO_OUTPUT}") + get_logger().debug( + f"Too many remaining files, clipping to {MAX_EXTRA_FILES_TO_OUTPUT}" + ) break extra_file_yaml = f"""\ @@ -364,10 +517,18 @@ class PRDescription: # merge the two dictionaries if counter_extra_files > 0: - get_logger().info(f"Adding {counter_extra_files} unprocessed extra files to table prediction") - prediction_extra_dict = load_yaml(prediction_extra, keys_fix_yaml=self.keys_fix) - if isinstance(original_prediction_dict, dict) and isinstance(prediction_extra_dict, dict): - original_prediction_dict["pr_files"].extend(prediction_extra_dict["pr_files"]) + get_logger().info( + f"Adding {counter_extra_files} unprocessed extra files to table prediction" + ) + prediction_extra_dict = load_yaml( + prediction_extra, keys_fix_yaml=self.keys_fix + ) + if isinstance(original_prediction_dict, dict) and isinstance( + prediction_extra_dict, dict + ): + original_prediction_dict["pr_files"].extend( + prediction_extra_dict["pr_files"] + ) new_yaml = yaml.dump(original_prediction_dict) if load_yaml(new_yaml, keys_fix_yaml=self.keys_fix): prediction = new_yaml @@ -379,11 +540,12 @@ class PRDescription: get_logger().error(f"Error extending uncovered files {self.pr_id}: {e}") return original_prediction - async def extend_additional_files(self, remaining_files_list) -> str: prediction = self.prediction try: - original_prediction_dict = load_yaml(self.prediction, keys_fix_yaml=self.keys_fix) + original_prediction_dict = load_yaml( + self.prediction, keys_fix_yaml=self.keys_fix + ) prediction_extra = "pr_files:" for file in remaining_files_list: extra_file_yaml = f"""\ @@ -397,10 +559,16 @@ class PRDescription: additional files (token-limit) """ prediction_extra = prediction_extra + "\n" + extra_file_yaml.strip() - prediction_extra_dict = load_yaml(prediction_extra, keys_fix_yaml=self.keys_fix) + prediction_extra_dict = load_yaml( + prediction_extra, keys_fix_yaml=self.keys_fix + ) # merge the two dictionaries - if isinstance(original_prediction_dict, dict) and isinstance(prediction_extra_dict, dict): - original_prediction_dict["pr_files"].extend(prediction_extra_dict["pr_files"]) + if isinstance(original_prediction_dict, dict) and isinstance( + prediction_extra_dict, dict + ): + original_prediction_dict["pr_files"].extend( + prediction_extra_dict["pr_files"] + ) new_yaml = yaml.dump(original_prediction_dict) if load_yaml(new_yaml, keys_fix_yaml=self.keys_fix): prediction = new_yaml @@ -409,7 +577,9 @@ class PRDescription: get_logger().error(f"Error extending additional files {self.pr_id}: {e}") return self.prediction - async def _get_prediction(self, model: str, patches_diff: str, prompt="pr_description_prompt") -> str: + async def _get_prediction( + self, model: str, patches_diff: str, prompt="pr_description_prompt" + ) -> str: variables = copy.deepcopy(self.vars) variables["diff"] = patches_diff # update diff @@ -417,14 +587,18 @@ class PRDescription: set_custom_labels(variables, self.git_provider) self.variables = variables - system_prompt = environment.from_string(get_settings().get(prompt, {}).get("system", "")).render(self.variables) - user_prompt = environment.from_string(get_settings().get(prompt, {}).get("user", "")).render(self.variables) + system_prompt = environment.from_string( + get_settings().get(prompt, {}).get("system", "") + ).render(self.variables) + user_prompt = environment.from_string( + 
get_settings().get(prompt, {}).get("user", "") + ).render(self.variables) response, finish_reason = await self.ai_handler.chat_completion( model=model, temperature=get_settings().config.temperature, system=system_prompt, - user=user_prompt + user=user_prompt, ) return response @@ -433,7 +607,10 @@ class PRDescription: # Load the AI prediction data into a dictionary self.data = load_yaml(self.prediction.strip(), keys_fix_yaml=self.keys_fix) - if get_settings().pr_description.add_original_user_description and self.user_description: + if ( + get_settings().pr_description.add_original_user_description + and self.user_description + ): self.data["User Description"] = self.user_description # re-order keys @@ -459,7 +636,11 @@ class PRDescription: pr_labels = self.data['labels'] elif type(self.data['labels']) == str: pr_labels = self.data['labels'].split(',') - elif 'type' in self.data and self.data['type'] and get_settings().pr_description.publish_labels: + elif ( + 'type' in self.data + and self.data['type'] + and get_settings().pr_description.publish_labels + ): if type(self.data['type']) == list: pr_labels = self.data['type'] elif type(self.data['type']) == str: @@ -474,7 +655,9 @@ class PRDescription: if label_i in d: pr_labels[i] = d[label_i] except Exception as e: - get_logger().error(f"Error converting labels to original case {self.pr_id}: {e}") + get_logger().error( + f"Error converting labels to original case {self.pr_id}: {e}" + ) return pr_labels def _prepare_pr_answer_with_markers(self) -> Tuple[str, str, str, List[dict]]: @@ -482,13 +665,13 @@ class PRDescription: # Remove the 'PR Title' key from the dictionary ai_title = self.data.pop('title', self.vars["title"]) - if (not get_settings().pr_description.generate_ai_title): + if not get_settings().pr_description.generate_ai_title: # Assign the original PR title to the 'title' variable title = self.vars["title"] else: # Assign the value of the 'PR Title' key to 'title' variable title = ai_title - + body = self.user_description if get_settings().pr_description.include_generated_by_header: ai_header = f"### 🤖 Generated by PR Agent at {self.git_provider.last_commit_id.sha}\n\n" @@ -514,8 +697,9 @@ class PRDescription: pr_file_changes = [] if ai_walkthrough and not re.search(r'', body): try: - walkthrough_gfm, pr_file_changes = self.process_pr_files_prediction(walkthrough_gfm, - self.file_label_dict) + walkthrough_gfm, pr_file_changes = self.process_pr_files_prediction( + walkthrough_gfm, self.file_label_dict + ) body = body.replace('pr_agent:walkthrough', walkthrough_gfm) except Exception as e: get_logger().error(f"Failing to process walkthrough {self.pr_id}: {e}") @@ -545,7 +729,7 @@ class PRDescription: # Remove the 'PR Title' key from the dictionary ai_title = self.data.pop('title', self.vars["title"]) - if (not get_settings().pr_description.generate_ai_title): + if not get_settings().pr_description.generate_ai_title: # Assign the original PR title to the 'title' variable title = self.vars["title"] else: @@ -575,13 +759,20 @@ class PRDescription: pr_body += f'- `{filename}`: {description}\n' if self.git_provider.is_supported("gfm_markdown"): pr_body += "
    \n" - elif 'pr_files' in key.lower() and get_settings().pr_description.enable_semantic_files_types: - changes_walkthrough, pr_file_changes = self.process_pr_files_prediction(changes_walkthrough, value) + elif ( + 'pr_files' in key.lower() + and get_settings().pr_description.enable_semantic_files_types + ): + changes_walkthrough, pr_file_changes = self.process_pr_files_prediction( + changes_walkthrough, value + ) changes_walkthrough = f"{PRDescriptionHeader.CHANGES_WALKTHROUGH.value}\n{changes_walkthrough}" elif key.lower().strip() == 'description': if isinstance(value, list): value = ', '.join(v.rstrip() for v in value) - value = value.replace('\n-', '\n\n-').strip() # makes the bullet points more readable by adding double space + value = value.replace( + '\n-', '\n\n-' + ).strip() # makes the bullet points more readable by adding double space pr_body += f"{value}\n" else: # if the value is a list, join its items by comma @@ -591,24 +782,37 @@ class PRDescription: if idx < len(self.data) - 1: pr_body += "\n\n___\n\n" - return title, pr_body, changes_walkthrough, pr_file_changes, + return ( + title, + pr_body, + changes_walkthrough, + pr_file_changes, + ) def _prepare_file_labels(self): file_label_dict = {} - if (not self.data or not isinstance(self.data, dict) or - 'pr_files' not in self.data or not self.data['pr_files']): + if ( + not self.data + or not isinstance(self.data, dict) + or 'pr_files' not in self.data + or not self.data['pr_files'] + ): return file_label_dict for file in self.data['pr_files']: try: required_fields = ['changes_title', 'filename', 'label'] if not all(field in file for field in required_fields): # can happen for example if a YAML generation was interrupted in the middle (no more tokens) - get_logger().warning(f"Missing required fields in file label dict {self.pr_id}, skipping file", - artifact={"file": file}) + get_logger().warning( + f"Missing required fields in file label dict {self.pr_id}, skipping file", + artifact={"file": file}, + ) continue if not file.get('changes_title'): - get_logger().warning(f"Empty changes title or summary in file label dict {self.pr_id}, skipping file", - artifact={"file": file}) + get_logger().warning( + f"Empty changes title or summary in file label dict {self.pr_id}, skipping file", + artifact={"file": file}, + ) continue filename = file['filename'].replace("'", "`").replace('"', '`') changes_summary = file.get('changes_summary', "").strip() @@ -616,7 +820,9 @@ class PRDescription: label = file.get('label').strip().lower() if label not in file_label_dict: file_label_dict[label] = [] - file_label_dict[label].append((filename, changes_title, changes_summary)) + file_label_dict[label].append( + (filename, changes_title, changes_summary) + ) except Exception as e: get_logger().error(f"Error preparing file label dict {self.pr_id}: {e}") pass @@ -640,7 +846,9 @@ class PRDescription: header = f"相关文件" delta = 75 # header += "  " * delta - pr_body += f"""{header}""" + pr_body += ( + f"""{header}""" + ) pr_body += """""" for semantic_label in value.keys(): s_label = semantic_label.strip("'").strip('"') @@ -651,14 +859,22 @@ class PRDescription: pr_body += f"""
    {len(list_tuples)} files""" else: pr_body += f""" @@ -735,6 +975,7 @@ class PRDescription: """ return pr_body + def count_chars_without_html(string): if '<' not in string: return len(string) diff --git a/apps/utils/pr_agent/tools/pr_generate_labels.py b/apps/utils/pr_agent/tools/pr_generate_labels.py index 85158e0..1eeabe7 100644 --- a/apps/utils/pr_agent/tools/pr_generate_labels.py +++ b/apps/utils/pr_agent/tools/pr_generate_labels.py @@ -16,8 +16,12 @@ from utils.pr_agent.log import get_logger class PRGenerateLabels: - def __init__(self, pr_url: str, args: list = None, - ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler): + def __init__( + self, + pr_url: str, + args: list = None, + ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler, + ): """ Initialize the PRGenerateLabels object with the necessary attributes and objects for generating labels corresponding to the PR using an AI model. @@ -93,7 +97,9 @@ class PRGenerateLabels: elif pr_labels: value = ', '.join(v for v in pr_labels) pr_labels_text = f"## PR Labels:\n{value}\n" - self.git_provider.publish_comment(pr_labels_text, is_temporary=False) + self.git_provider.publish_comment( + pr_labels_text, is_temporary=False + ) self.git_provider.remove_initial_comment() except Exception as e: get_logger().error(f"Error generating PR labels {self.pr_id}: {e}") @@ -137,14 +143,18 @@ class PRGenerateLabels: set_custom_labels(variables, self.git_provider) self.variables = variables - system_prompt = environment.from_string(get_settings().pr_custom_labels_prompt.system).render(self.variables) - user_prompt = environment.from_string(get_settings().pr_custom_labels_prompt.user).render(self.variables) + system_prompt = environment.from_string( + get_settings().pr_custom_labels_prompt.system + ).render(self.variables) + user_prompt = environment.from_string( + get_settings().pr_custom_labels_prompt.user + ).render(self.variables) response, finish_reason = await self.ai_handler.chat_completion( model=model, temperature=get_settings().config.temperature, system=system_prompt, - user=user_prompt + user=user_prompt, ) return response @@ -153,8 +163,6 @@ class PRGenerateLabels: # Load the AI prediction data into a dictionary self.data = load_yaml(self.prediction.strip()) - - def _prepare_labels(self) -> List[str]: pr_types = [] @@ -174,6 +182,8 @@ class PRGenerateLabels: if label_i in d: pr_types[i] = d[label_i] except Exception as e: - get_logger().error(f"Error converting labels to original case {self.pr_id}: {e}") + get_logger().error( + f"Error converting labels to original case {self.pr_id}: {e}" + ) return pr_types diff --git a/apps/utils/pr_agent/tools/pr_help_message.py b/apps/utils/pr_agent/tools/pr_help_message.py index ca83b46..fda1d0e 100644 --- a/apps/utils/pr_agent/tools/pr_help_message.py +++ b/apps/utils/pr_agent/tools/pr_help_message.py @@ -12,7 +12,11 @@ from utils.pr_agent.algo.pr_processing import retry_with_fallback_models from utils.pr_agent.algo.token_handler import TokenHandler from utils.pr_agent.algo.utils import ModelType, clip_tokens, load_yaml, get_max_tokens from utils.pr_agent.config_loader import get_settings -from utils.pr_agent.git_providers import BitbucketServerProvider, GithubProvider, get_git_provider_with_context +from utils.pr_agent.git_providers import ( + BitbucketServerProvider, + GithubProvider, + get_git_provider_with_context, +) from utils.pr_agent.log import get_logger @@ -29,31 +33,50 @@ def extract_header(snippet): res = f"#{highest_header.lower().replace(' ', '-')}" return res + class 
PRHelpMessage: - def __init__(self, pr_url: str, args=None, ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler, return_as_string=False): + def __init__( + self, + pr_url: str, + args=None, + ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler, + return_as_string=False, + ): self.git_provider = get_git_provider_with_context(pr_url) self.ai_handler = ai_handler() self.question_str = self.parse_args(args) self.return_as_string = return_as_string - self.num_retrieved_snippets = get_settings().get('pr_help.num_retrieved_snippets', 5) + self.num_retrieved_snippets = get_settings().get( + 'pr_help.num_retrieved_snippets', 5 + ) if self.question_str: self.vars = { "question": self.question_str, "snippets": "", } - self.token_handler = TokenHandler(None, - self.vars, - get_settings().pr_help_prompts.system, - get_settings().pr_help_prompts.user) + self.token_handler = TokenHandler( + None, + self.vars, + get_settings().pr_help_prompts.system, + get_settings().pr_help_prompts.user, + ) async def _prepare_prediction(self, model: str): try: variables = copy.deepcopy(self.vars) environment = Environment(undefined=StrictUndefined) - system_prompt = environment.from_string(get_settings().pr_help_prompts.system).render(variables) - user_prompt = environment.from_string(get_settings().pr_help_prompts.user).render(variables) + system_prompt = environment.from_string( + get_settings().pr_help_prompts.system + ).render(variables) + user_prompt = environment.from_string( + get_settings().pr_help_prompts.user + ).render(variables) response, finish_reason = await self.ai_handler.chat_completion( - model=model, temperature=get_settings().config.temperature, system=system_prompt, user=user_prompt) + model=model, + temperature=get_settings().config.temperature, + system=system_prompt, + user=user_prompt, + ) return response except Exception as e: get_logger().error(f"Error while preparing prediction: {e}") @@ -81,7 +104,7 @@ class PRHelpMessage: '.': '', '?': '', '!': '', - ' ': '-' + ' ': '-', } # Compile regex pattern for characters to remove @@ -90,37 +113,69 @@ class PRHelpMessage: # Perform replacements in a single pass and convert to lowercase return pattern.sub(lambda m: replacements[m.group()], cleaned).lower() except Exception: - get_logger().exception(f"Error while formatting markdown header", artifacts={'header': header}) + get_logger().exception( + f"Error while formatting markdown header", artifacts={'header': header} + ) return "" - async def run(self): try: if self.question_str: - get_logger().info(f'Answering a PR question about the PR {self.git_provider.pr_url} ') + get_logger().info( + f'Answering a PR question about the PR {self.git_provider.pr_url} ' + ) if not get_settings().get('openai.key'): if get_settings().config.publish_output: self.git_provider.publish_comment( - "The `Help` tool chat feature requires an OpenAI API key for calculating embeddings") + "The `Help` tool chat feature requires an OpenAI API key for calculating embeddings" + ) else: - get_logger().error("The `Help` tool chat feature requires an OpenAI API key for calculating embeddings") + get_logger().error( + "The `Help` tool chat feature requires an OpenAI API key for calculating embeddings" + ) return # current path - docs_path= Path(__file__).parent.parent.parent / 'docs' / 'docs' + docs_path = Path(__file__).parent.parent.parent / 'docs' / 'docs' # get all the 'md' files inside docs_path and its subdirectories md_files = list(docs_path.glob('**/*.md')) folders_to_exclude = ['/finetuning_benchmark/'] - 
files_to_exclude = {'EXAMPLE_BEST_PRACTICE.md', 'compression_strategy.md', '/docs/overview/index.md'} - md_files = [file for file in md_files if not any(folder in str(file) for folder in folders_to_exclude) and not any(file.name == file_to_exclude for file_to_exclude in files_to_exclude)] + files_to_exclude = { + 'EXAMPLE_BEST_PRACTICE.md', + 'compression_strategy.md', + '/docs/overview/index.md', + } + md_files = [ + file + for file in md_files + if not any(folder in str(file) for folder in folders_to_exclude) + and not any( + file.name == file_to_exclude + for file_to_exclude in files_to_exclude + ) + ] # sort the 'md_files' so that 'priority_files' will be at the top - priority_files_strings = ['/docs/index.md', '/usage-guide', 'tools/describe.md', 'tools/review.md', - 'tools/improve.md', '/faq'] - md_files_priority = [file for file in md_files if - any(priority_string in str(file) for priority_string in priority_files_strings)] - md_files_not_priority = [file for file in md_files if file not in md_files_priority] + priority_files_strings = [ + '/docs/index.md', + '/usage-guide', + 'tools/describe.md', + 'tools/review.md', + 'tools/improve.md', + '/faq', + ] + md_files_priority = [ + file + for file in md_files + if any( + priority_string in str(file) + for priority_string in priority_files_strings + ) + ] + md_files_not_priority = [ + file for file in md_files if file not in md_files_priority + ] md_files = md_files_priority + md_files_not_priority docs_prompt = "" @@ -132,24 +187,36 @@ class PRHelpMessage: except Exception as e: get_logger().error(f"Error while reading the file {file}: {e}") token_count = self.token_handler.count_tokens(docs_prompt) - get_logger().debug(f"Token count of full documentation website: {token_count}") + get_logger().debug( + f"Token count of full documentation website: {token_count}" + ) model = get_settings().config.model if model in MAX_TOKENS: - max_tokens_full = MAX_TOKENS[model] # note - here we take the actual max tokens, without any reductions. we do aim to get the full documentation website in the prompt + max_tokens_full = MAX_TOKENS[ + model + ] # note - here we take the actual max tokens, without any reductions. we do aim to get the full documentation website in the prompt else: max_tokens_full = get_max_tokens(model) delta_output = 2000 if token_count > max_tokens_full - delta_output: - get_logger().info(f"Token count {token_count} exceeds the limit {max_tokens_full - delta_output}. Skipping the PR Help message.") - docs_prompt = clip_tokens(docs_prompt, max_tokens_full - delta_output) + get_logger().info( + f"Token count {token_count} exceeds the limit {max_tokens_full - delta_output}. Skipping the PR Help message." 
+ ) + docs_prompt = clip_tokens( + docs_prompt, max_tokens_full - delta_output + ) self.vars['snippets'] = docs_prompt.strip() # run the AI model - response = await retry_with_fallback_models(self._prepare_prediction, model_type=ModelType.REGULAR) + response = await retry_with_fallback_models( + self._prepare_prediction, model_type=ModelType.REGULAR + ) response_yaml = load_yaml(response) if isinstance(response_yaml, str): - get_logger().warning(f"failing to parse response: {response_yaml}, publishing the response as is") + get_logger().warning( + f"failing to parse response: {response_yaml}, publishing the response as is" + ) if get_settings().config.publish_output: answer_str = f"### Question: \n{self.question_str}\n\n" answer_str += f"### Answer:\n\n" @@ -160,7 +227,9 @@ class PRHelpMessage: relevant_sections = response_yaml.get('relevant_sections') if not relevant_sections: - get_logger().info(f"Could not find relevant answer for the question: {self.question_str}") + get_logger().info( + f"Could not find relevant answer for the question: {self.question_str}" + ) if get_settings().config.publish_output: answer_str = f"### Question: \n{self.question_str}\n\n" answer_str += f"### Answer:\n\n" @@ -178,29 +247,38 @@ class PRHelpMessage: for section in relevant_sections: file = section.get('file_name').strip().removesuffix('.md') if str(section['relevant_section_header_string']).strip(): - markdown_header = self.format_markdown_header(section['relevant_section_header_string']) + markdown_header = self.format_markdown_header( + section['relevant_section_header_string'] + ) answer_str += f"> - {base_path}{file}#{markdown_header}\n" else: answer_str += f"> - {base_path}{file}\n" - # publish the answer if get_settings().config.publish_output: self.git_provider.publish_comment(answer_str) else: get_logger().info(f"Answer:\n{answer_str}") else: - if not isinstance(self.git_provider, BitbucketServerProvider) and not self.git_provider.is_supported("gfm_markdown"): + if not isinstance( + self.git_provider, BitbucketServerProvider + ) and not self.git_provider.is_supported("gfm_markdown"): self.git_provider.publish_comment( - "The `Help` tool requires gfm markdown, which is not supported by your code platform.") + "The `Help` tool requires gfm markdown, which is not supported by your code platform." + ) return get_logger().info('Getting PR Help Message...') - relevant_configs = {'pr_help': dict(get_settings().pr_help), - 'config': dict(get_settings().config)} + relevant_configs = { + 'pr_help': dict(get_settings().pr_help), + 'config': dict(get_settings().config), + } get_logger().debug("Relevant configs", artifacts=relevant_configs) pr_comment = "## PR Agent Walkthrough 🤖\n\n" - pr_comment += "Welcome to the PR Agent, an AI-powered tool for automated pull request analysis, feedback, suggestions and more.""" + pr_comment += ( + "Welcome to the PR Agent, an AI-powered tool for automated pull request analysis, feedback, suggestions and more." 
+ "" + ) pr_comment += "\n\nHere is a list of tools you can use to interact with the PR Agent:\n" base_path = "https://pr-agent-docs.codium.ai/tools" @@ -211,32 +289,58 @@ class PRHelpMessage: tool_names.append(f"[UPDATE CHANGELOG]({base_path}/update_changelog/)") tool_names.append(f"[ADD DOCS]({base_path}/documentation/) 💎") tool_names.append(f"[TEST]({base_path}/test/) 💎") - tool_names.append(f"[IMPROVE COMPONENT]({base_path}/improve_component/) 💎") + tool_names.append( + f"[IMPROVE COMPONENT]({base_path}/improve_component/) 💎" + ) tool_names.append(f"[ANALYZE]({base_path}/analyze/) 💎") tool_names.append(f"[ASK]({base_path}/ask/)") tool_names.append(f"[SIMILAR ISSUE]({base_path}/similar_issues/)") - tool_names.append(f"[GENERATE CUSTOM LABELS]({base_path}/custom_labels/) 💎") + tool_names.append( + f"[GENERATE CUSTOM LABELS]({base_path}/custom_labels/) 💎" + ) tool_names.append(f"[CI FEEDBACK]({base_path}/ci_feedback/) 💎") tool_names.append(f"[CUSTOM PROMPT]({base_path}/custom_prompt/) 💎") tool_names.append(f"[IMPLEMENT]({base_path}/implement/) 💎") descriptions = [] - descriptions.append("Generates PR description - title, type, summary, code walkthrough and labels") - descriptions.append("Adjustable feedback about the PR, possible issues, security concerns, review effort and more") + descriptions.append( + "Generates PR description - title, type, summary, code walkthrough and labels" + ) + descriptions.append( + "Adjustable feedback about the PR, possible issues, security concerns, review effort and more" + ) descriptions.append("Code suggestions for improving the PR") descriptions.append("Automatically updates the changelog") - descriptions.append("Generates documentation to methods/functions/classes that changed in the PR") - descriptions.append("Generates unit tests for a specific component, based on the PR code change") - descriptions.append("Code suggestions for a specific component that changed in the PR") - descriptions.append("Identifies code components that changed in the PR, and enables to interactively generate tests, docs, and code suggestions for each component") + descriptions.append( + "Generates documentation to methods/functions/classes that changed in the PR" + ) + descriptions.append( + "Generates unit tests for a specific component, based on the PR code change" + ) + descriptions.append( + "Code suggestions for a specific component that changed in the PR" + ) + descriptions.append( + "Identifies code components that changed in the PR, and enables to interactively generate tests, docs, and code suggestions for each component" + ) descriptions.append("Answering free-text questions about the PR") - descriptions.append("Automatically retrieves and presents similar issues") - descriptions.append("Generates custom labels for the PR, based on specific guidelines defined by the user") - descriptions.append("Generates feedback and analysis for a failed CI job") - descriptions.append("Generates custom suggestions for improving the PR code, derived only from a specific guidelines prompt defined by the user") - descriptions.append("Generates implementation code from review suggestions") + descriptions.append( + "Automatically retrieves and presents similar issues" + ) + descriptions.append( + "Generates custom labels for the PR, based on specific guidelines defined by the user" + ) + descriptions.append( + "Generates feedback and analysis for a failed CI job" + ) + descriptions.append( + "Generates custom suggestions for improving the PR code, derived only from a specific 
guidelines prompt defined by the user" + ) + descriptions.append( + "Generates implementation code from review suggestions" + ) - commands =[] + commands = [] commands.append("`/describe`") commands.append("`/review`") commands.append("`/improve`") @@ -271,7 +375,9 @@ class PRHelpMessage: checkbox_list.append("[*]") checkbox_list.append("[*]") - if isinstance(self.git_provider, GithubProvider) and not get_settings().config.get('disable_checkboxes', False): + if isinstance( + self.git_provider, GithubProvider + ) and not get_settings().config.get('disable_checkboxes', False): pr_comment += f"
    """ - for filename, file_changes_title, file_change_description in list_tuples: + for ( + filename, + file_changes_title, + file_change_description, + ) in list_tuples: filename = filename.replace("'", "`").rstrip() filename_publish = filename.split("/")[-1] if file_changes_title and file_changes_title.strip() != "...": file_changes_title_code = f"{file_changes_title}" - file_changes_title_code_br = insert_br_after_x_chars(file_changes_title_code, x=(delta - 5)).strip() + file_changes_title_code_br = insert_br_after_x_chars( + file_changes_title_code, x=(delta - 5) + ).strip() if len(file_changes_title_code_br) < (delta - 5): - file_changes_title_code_br += "  " * ((delta - 5) - len(file_changes_title_code_br)) + file_changes_title_code_br += "  " * ( + (delta - 5) - len(file_changes_title_code_br) + ) filename_publish = f"{filename_publish}
    {file_changes_title_code_br}
    " else: filename_publish = f"{filename_publish}" @@ -679,15 +895,30 @@ class PRDescription: link = "" if hasattr(self.git_provider, 'get_line_link'): filename = filename.strip() - link = self.git_provider.get_line_link(filename, relevant_line_start=-1) - if (not link or not diff_plus_minus) and ('additional files' not in filename.lower()): - get_logger().warning(f"Error getting line link for '{filename}'") + link = self.git_provider.get_line_link( + filename, relevant_line_start=-1 + ) + if (not link or not diff_plus_minus) and ( + 'additional files' not in filename.lower() + ): + get_logger().warning( + f"Error getting line link for '{filename}'" + ) continue # Add file data to the PR body - file_change_description_br = insert_br_after_x_chars(file_change_description, x=(delta - 5)) - pr_body = self.add_file_data(delta_nbsp, diff_plus_minus, file_change_description_br, filename, - filename_publish, link, pr_body) + file_change_description_br = insert_br_after_x_chars( + file_change_description, x=(delta - 5) + ) + pr_body = self.add_file_data( + delta_nbsp, + diff_plus_minus, + file_change_description_br, + filename, + filename_publish, + link, + pr_body, + ) # Close the collapsible file list if use_collapsible_file_list: @@ -697,13 +928,22 @@ class PRDescription: pr_body += """
    """ except Exception as e: - get_logger().error(f"Error processing pr files to markdown {self.pr_id}: {str(e)}") + get_logger().error( + f"Error processing pr files to markdown {self.pr_id}: {str(e)}" + ) pass return pr_body, pr_comments - def add_file_data(self, delta_nbsp, diff_plus_minus, file_change_description_br, filename, filename_publish, link, - pr_body) -> str: - + def add_file_data( + self, + delta_nbsp, + diff_plus_minus, + file_change_description_br, + filename, + filename_publish, + link, + pr_body, + ) -> str: if not file_change_description_br: pr_body += f"""
    " for i in range(len(tool_names)): pr_comment += f"\n\n\n" diff --git a/apps/utils/pr_agent/tools/pr_line_questions.py b/apps/utils/pr_agent/tools/pr_line_questions.py index 5067be1..60d330a 100644 --- a/apps/utils/pr_agent/tools/pr_line_questions.py +++ b/apps/utils/pr_agent/tools/pr_line_questions.py @@ -5,8 +5,7 @@ from jinja2 import Environment, StrictUndefined from utils.pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler from utils.pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler -from utils.pr_agent.algo.git_patch_processing import ( - extract_hunk_lines_from_patch) +from utils.pr_agent.algo.git_patch_processing import extract_hunk_lines_from_patch from utils.pr_agent.algo.pr_processing import retry_with_fallback_models from utils.pr_agent.algo.token_handler import TokenHandler from utils.pr_agent.algo.utils import ModelType @@ -17,7 +16,12 @@ from utils.pr_agent.log import get_logger class PR_LineQuestions: - def __init__(self, pr_url: str, args=None, ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler): + def __init__( + self, + pr_url: str, + args=None, + ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler, + ): self.question_str = self.parse_args(args) self.git_provider = get_git_provider()(pr_url) self.main_pr_language = get_main_pr_language( @@ -34,10 +38,12 @@ class PR_LineQuestions: "full_hunk": "", "selected_lines": "", } - self.token_handler = TokenHandler(self.git_provider.pr, - self.vars, - get_settings().pr_line_questions_prompt.system, - get_settings().pr_line_questions_prompt.user) + self.token_handler = TokenHandler( + self.git_provider.pr, + self.vars, + get_settings().pr_line_questions_prompt.system, + get_settings().pr_line_questions_prompt.user, + ) self.patches_diff = None self.prediction = None @@ -48,7 +54,6 @@ class PR_LineQuestions: question_str = "" return question_str - async def run(self): get_logger().info('Answering a PR lines question...') # if get_settings().config.publish_output: @@ -62,22 +67,27 @@ class PR_LineQuestions: file_name = get_settings().get('file_name', '') comment_id = get_settings().get('comment_id', '') if ask_diff: - self.patch_with_lines, self.selected_lines = extract_hunk_lines_from_patch(ask_diff, - file_name, - line_start=line_start, - line_end=line_end, - side=side - ) + self.patch_with_lines, self.selected_lines = extract_hunk_lines_from_patch( + ask_diff, file_name, line_start=line_start, line_end=line_end, side=side + ) else: diff_files = self.git_provider.get_diff_files() for file in diff_files: if file.filename == file_name: - self.patch_with_lines, self.selected_lines = extract_hunk_lines_from_patch(file.patch, file.filename, - line_start=line_start, - line_end=line_end, - side=side) + ( + self.patch_with_lines, + self.selected_lines, + ) = extract_hunk_lines_from_patch( + file.patch, + file.filename, + line_start=line_start, + line_end=line_end, + side=side, + ) if self.patch_with_lines: - model_answer = await retry_with_fallback_models(self._get_prediction, model_type=ModelType.WEAK) + model_answer = await retry_with_fallback_models( + self._get_prediction, model_type=ModelType.WEAK + ) # sanitize the answer so that no line will start with "/" model_answer_sanitized = model_answer.strip().replace("\n/", "\n /") if model_answer_sanitized.startswith("/"): @@ -85,7 +95,9 @@ class PR_LineQuestions: get_logger().info('Preparing answer...') if comment_id: - self.git_provider.reply_to_comment_from_comment_id(comment_id, model_answer_sanitized) + 
self.git_provider.reply_to_comment_from_comment_id( + comment_id, model_answer_sanitized + ) else: self.git_provider.publish_comment(model_answer_sanitized) @@ -96,8 +108,12 @@ class PR_LineQuestions: variables["full_hunk"] = self.patch_with_lines # update diff variables["selected_lines"] = self.selected_lines environment = Environment(undefined=StrictUndefined) - system_prompt = environment.from_string(get_settings().pr_line_questions_prompt.system).render(variables) - user_prompt = environment.from_string(get_settings().pr_line_questions_prompt.user).render(variables) + system_prompt = environment.from_string( + get_settings().pr_line_questions_prompt.system + ).render(variables) + user_prompt = environment.from_string( + get_settings().pr_line_questions_prompt.user + ).render(variables) if get_settings().config.verbosity_level >= 2: # get_logger().info(f"\nSystem prompt:\n{system_prompt}") # get_logger().info(f"\nUser prompt:\n{user_prompt}") @@ -105,5 +121,9 @@ class PR_LineQuestions: print(f"\nUser prompt:\n{user_prompt}") response, finish_reason = await self.ai_handler.chat_completion( - model=model, temperature=get_settings().config.temperature, system=system_prompt, user=user_prompt) + model=model, + temperature=get_settings().config.temperature, + system=system_prompt, + user=user_prompt, + ) return response diff --git a/apps/utils/pr_agent/tools/pr_questions.py b/apps/utils/pr_agent/tools/pr_questions.py index a1dae7b..081e03b 100644 --- a/apps/utils/pr_agent/tools/pr_questions.py +++ b/apps/utils/pr_agent/tools/pr_questions.py @@ -16,7 +16,12 @@ from utils.pr_agent.servers.help import HelpMessage class PRQuestions: - def __init__(self, pr_url: str, args=None, ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler): + def __init__( + self, + pr_url: str, + args=None, + ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler, + ): question_str = self.parse_args(args) self.pr_url = pr_url self.git_provider = get_git_provider()(pr_url) @@ -36,10 +41,12 @@ class PRQuestions: "questions": self.question_str, "commit_messages_str": self.git_provider.get_commit_messages(), } - self.token_handler = TokenHandler(self.git_provider.pr, - self.vars, - get_settings().pr_questions_prompt.system, - get_settings().pr_questions_prompt.user) + self.token_handler = TokenHandler( + self.git_provider.pr, + self.vars, + get_settings().pr_questions_prompt.system, + get_settings().pr_questions_prompt.user, + ) self.patches_diff = None self.prediction = None @@ -52,8 +59,10 @@ class PRQuestions: async def run(self): get_logger().info(f'Answering a PR question about the PR {self.pr_url} ') - relevant_configs = {'pr_questions': dict(get_settings().pr_questions), - 'config': dict(get_settings().config)} + relevant_configs = { + 'pr_questions': dict(get_settings().pr_questions), + 'config': dict(get_settings().config), + } get_logger().debug("Relevant configs", artifacts=relevant_configs) if get_settings().config.publish_output: self.git_provider.publish_comment("思考回答中...", is_temporary=True) @@ -63,12 +72,17 @@ class PRQuestions: if img_path: get_logger().debug(f"Image path identified", artifact=img_path) - await retry_with_fallback_models(self._prepare_prediction, model_type=ModelType.WEAK) + await retry_with_fallback_models( + self._prepare_prediction, model_type=ModelType.WEAK + ) pr_comment = self._prepare_pr_answer() get_logger().debug(f"PR output", artifact=pr_comment) - if self.git_provider.is_supported("gfm_markdown") and get_settings().pr_questions.enable_help_text: + if ( + 
self.git_provider.is_supported("gfm_markdown") + and get_settings().pr_questions.enable_help_text + ): pr_comment += "
\n\n<details> <summary><strong>💡 Tool usage guide:</strong></summary><hr> \n\n"
pr_comment += HelpMessage.get_ask_usage_guide()
pr_comment += "\n</details>
    \n" @@ -85,7 +99,9 @@ class PRQuestions: # /ask question ... > ![image](img_path) img_path = self.question_str.split('![image]')[1].strip().strip('()') self.vars['img_path'] = img_path - elif 'https://' in self.question_str and ('.png' in self.question_str or 'jpg' in self.question_str): # direct image link + elif 'https://' in self.question_str and ( + '.png' in self.question_str or 'jpg' in self.question_str + ): # direct image link # include https:// in the image path img_path = 'https://' + self.question_str.split('https://')[1] self.vars['img_path'] = img_path @@ -104,16 +120,28 @@ class PRQuestions: variables = copy.deepcopy(self.vars) variables["diff"] = self.patches_diff # update diff environment = Environment(undefined=StrictUndefined) - system_prompt = environment.from_string(get_settings().pr_questions_prompt.system).render(variables) - user_prompt = environment.from_string(get_settings().pr_questions_prompt.user).render(variables) + system_prompt = environment.from_string( + get_settings().pr_questions_prompt.system + ).render(variables) + user_prompt = environment.from_string( + get_settings().pr_questions_prompt.user + ).render(variables) if 'img_path' in variables: img_path = self.vars['img_path'] - response, finish_reason = await (self.ai_handler.chat_completion - (model=model, temperature=get_settings().config.temperature, - system=system_prompt, user=user_prompt, img_path=img_path)) + response, finish_reason = await self.ai_handler.chat_completion( + model=model, + temperature=get_settings().config.temperature, + system=system_prompt, + user=user_prompt, + img_path=img_path, + ) else: response, finish_reason = await self.ai_handler.chat_completion( - model=model, temperature=get_settings().config.temperature, system=system_prompt, user=user_prompt) + model=model, + temperature=get_settings().config.temperature, + system=system_prompt, + user=user_prompt, + ) return response def _prepare_pr_answer(self) -> str: @@ -123,9 +151,13 @@ class PRQuestions: if model_answer_sanitized.startswith("/"): model_answer_sanitized = " " + model_answer_sanitized if model_answer_sanitized != model_answer: - get_logger().debug(f"Sanitized model answer", - artifact={"model_answer": model_answer, "sanitized_answer": model_answer_sanitized}) - + get_logger().debug( + f"Sanitized model answer", + artifact={ + "model_answer": model_answer, + "sanitized_answer": model_answer_sanitized, + }, + ) answer_str = f"### **Ask**❓\n{self.question_str}\n\n" answer_str += f"### **Answer:**\n{model_answer_sanitized}\n\n" diff --git a/apps/utils/pr_agent/tools/pr_reviewer.py b/apps/utils/pr_agent/tools/pr_reviewer.py index e21628f..ef4b7dd 100644 --- a/apps/utils/pr_agent/tools/pr_reviewer.py +++ b/apps/utils/pr_agent/tools/pr_reviewer.py @@ -7,21 +7,29 @@ from jinja2 import Environment, StrictUndefined from utils.pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler from utils.pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler -from utils.pr_agent.algo.pr_processing import (add_ai_metadata_to_diff_files, - get_pr_diff, - retry_with_fallback_models) +from utils.pr_agent.algo.pr_processing import ( + add_ai_metadata_to_diff_files, + get_pr_diff, + retry_with_fallback_models, +) from utils.pr_agent.algo.token_handler import TokenHandler -from utils.pr_agent.algo.utils import (ModelType, PRReviewHeader, - convert_to_markdown_v2, github_action_output, - load_yaml, show_relevant_configurations) +from utils.pr_agent.algo.utils import ( + ModelType, + PRReviewHeader, + 
convert_to_markdown_v2, + github_action_output, + load_yaml, + show_relevant_configurations, +) from utils.pr_agent.config_loader import get_settings -from utils.pr_agent.git_providers import (get_git_provider_with_context) -from utils.pr_agent.git_providers.git_provider import (IncrementalPR, - get_main_pr_language) +from utils.pr_agent.git_providers import get_git_provider_with_context +from utils.pr_agent.git_providers.git_provider import ( + IncrementalPR, + get_main_pr_language, +) from utils.pr_agent.log import get_logger from utils.pr_agent.servers.help import HelpMessage -from utils.pr_agent.tools.ticket_pr_compliance_check import ( - extract_and_cache_pr_tickets) +from utils.pr_agent.tools.ticket_pr_compliance_check import extract_and_cache_pr_tickets class PRReviewer: @@ -29,8 +37,14 @@ class PRReviewer: The PRReviewer class is responsible for reviewing a pull request and generating feedback using an AI model. """ - def __init__(self, pr_url: str, is_answer: bool = False, is_auto: bool = False, args: list = None, - ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler): + def __init__( + self, + pr_url: str, + is_answer: bool = False, + is_auto: bool = False, + args: list = None, + ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler, + ): """ Initialize the PRReviewer object with the necessary attributes and objects to review a pull request. @@ -55,16 +69,23 @@ class PRReviewer: self.is_auto = is_auto if self.is_answer and not self.git_provider.is_supported("get_issue_comments"): - raise Exception(f"Answer mode is not supported for {get_settings().config.git_provider} for now") + raise Exception( + f"Answer mode is not supported for {get_settings().config.git_provider} for now" + ) self.ai_handler = ai_handler() self.ai_handler.main_pr_language = self.main_language self.patches_diff = None self.prediction = None answer_str, question_str = self._get_user_answers() - self.pr_description, self.pr_description_files = ( - self.git_provider.get_pr_description(split_changes_walkthrough=True)) - if (self.pr_description_files and get_settings().get("config.is_auto_command", False) and - get_settings().get("config.enable_ai_metadata", False)): + ( + self.pr_description, + self.pr_description_files, + ) = self.git_provider.get_pr_description(split_changes_walkthrough=True) + if ( + self.pr_description_files + and get_settings().get("config.is_auto_command", False) + and get_settings().get("config.enable_ai_metadata", False) + ): add_ai_metadata_to_diff_files(self.git_provider, self.pr_description_files) get_logger().debug(f"AI metadata added to the this command") else: @@ -89,9 +110,11 @@ class PRReviewer: "commit_messages_str": self.git_provider.get_commit_messages(), "custom_labels": "", "enable_custom_labels": get_settings().config.enable_custom_labels, - "is_ai_metadata": get_settings().get("config.enable_ai_metadata", False), + "is_ai_metadata": get_settings().get("config.enable_ai_metadata", False), "related_tickets": get_settings().get('related_tickets', []), - 'duplicate_prompt_examples': get_settings().config.get('duplicate_prompt_examples', False), + 'duplicate_prompt_examples': get_settings().config.get( + 'duplicate_prompt_examples', False + ), "date": datetime.datetime.now().strftime('%Y-%m-%d'), } @@ -99,7 +122,7 @@ class PRReviewer: self.git_provider.pr, self.vars, get_settings().pr_review_prompt.system, - get_settings().pr_review_prompt.user + get_settings().pr_review_prompt.user, ) def parse_incremental(self, args: List[str]): @@ -117,7 +140,10 @@ class PRReviewer: 
get_logger().info(f"PR has no files: {self.pr_url}, skipping review") return None - if self.incremental.is_incremental and not self._can_run_incremental_review(): + if ( + self.incremental.is_incremental + and not self._can_run_incremental_review() + ): return None # if isinstance(self.args, list) and self.args and self.args[0] == 'auto_approve': @@ -126,27 +152,41 @@ class PRReviewer: # return None get_logger().info(f'Reviewing PR: {self.pr_url} ...') - relevant_configs = {'pr_reviewer': dict(get_settings().pr_reviewer), - 'config': dict(get_settings().config)} + relevant_configs = { + 'pr_reviewer': dict(get_settings().pr_reviewer), + 'config': dict(get_settings().config), + } get_logger().debug("Relevant configs", artifacts=relevant_configs) # ticket extraction if exists await extract_and_cache_pr_tickets(self.git_provider, self.vars) - if self.incremental.is_incremental and hasattr(self.git_provider, "unreviewed_files_set") and not self.git_provider.unreviewed_files_set: - get_logger().info(f"Incremental review is enabled for {self.pr_url} but there are no new files") + if ( + self.incremental.is_incremental + and hasattr(self.git_provider, "unreviewed_files_set") + and not self.git_provider.unreviewed_files_set + ): + get_logger().info( + f"Incremental review is enabled for {self.pr_url} but there are no new files" + ) previous_review_url = "" if hasattr(self.git_provider, "previous_review"): previous_review_url = self.git_provider.previous_review.html_url if get_settings().config.publish_output: - self.git_provider.publish_comment(f"Incremental Review Skipped\n" - f"No files were changed since the [previous PR Review]({previous_review_url})") + self.git_provider.publish_comment( + f"Incremental Review Skipped\n" + f"No files were changed since the [previous PR Review]({previous_review_url})" + ) return None - if get_settings().config.publish_output and not get_settings().config.get('is_auto_command', False): + if get_settings().config.publish_output and not get_settings().config.get( + 'is_auto_command', False + ): self.git_provider.publish_comment("准备评审中...", is_temporary=True) - await retry_with_fallback_models(self._prepare_prediction, model_type=ModelType.REGULAR) + await retry_with_fallback_models( + self._prepare_prediction, model_type=ModelType.REGULAR + ) if not self.prediction: self.git_provider.remove_initial_comment() return None @@ -156,12 +196,19 @@ class PRReviewer: if get_settings().config.publish_output: # publish the review - if get_settings().pr_reviewer.persistent_comment and not self.incremental.is_incremental: - final_update_message = get_settings().pr_reviewer.final_update_message - self.git_provider.publish_persistent_comment(pr_review, - initial_header=f"{PRReviewHeader.REGULAR.value} 🔍", - update_header=True, - final_update_message=final_update_message, ) + if ( + get_settings().pr_reviewer.persistent_comment + and not self.incremental.is_incremental + ): + final_update_message = ( + get_settings().pr_reviewer.final_update_message + ) + self.git_provider.publish_persistent_comment( + pr_review, + initial_header=f"{PRReviewHeader.REGULAR.value} 🔍", + update_header=True, + final_update_message=final_update_message, + ) else: self.git_provider.publish_comment(pr_review) @@ -174,11 +221,13 @@ class PRReviewer: get_logger().error(f"Failed to review PR: {e}") async def _prepare_prediction(self, model: str) -> None: - self.patches_diff = get_pr_diff(self.git_provider, - self.token_handler, - model, - add_line_numbers_to_hunks=True, - disable_extra_lines=False,) + 
self.patches_diff = get_pr_diff( + self.git_provider, + self.token_handler, + model, + add_line_numbers_to_hunks=True, + disable_extra_lines=False, + ) if self.patches_diff: get_logger().debug(f"PR diff", diff=self.patches_diff) @@ -201,14 +250,18 @@ class PRReviewer: variables["diff"] = self.patches_diff # update diff environment = Environment(undefined=StrictUndefined) - system_prompt = environment.from_string(get_settings().pr_review_prompt.system).render(variables) - user_prompt = environment.from_string(get_settings().pr_review_prompt.user).render(variables) + system_prompt = environment.from_string( + get_settings().pr_review_prompt.system + ).render(variables) + user_prompt = environment.from_string( + get_settings().pr_review_prompt.user + ).render(variables) response, finish_reason = await self.ai_handler.chat_completion( model=model, temperature=get_settings().config.temperature, system=system_prompt, - user=user_prompt + user=user_prompt, ) return response @@ -220,10 +273,20 @@ class PRReviewer: """ first_key = 'review' last_key = 'security_concerns' - data = load_yaml(self.prediction.strip(), - keys_fix_yaml=["ticket_compliance_check", "estimated_effort_to_review_[1-5]:", "security_concerns:", "key_issues_to_review:", - "relevant_file:", "relevant_line:", "suggestion:"], - first_key=first_key, last_key=last_key) + data = load_yaml( + self.prediction.strip(), + keys_fix_yaml=[ + "ticket_compliance_check", + "estimated_effort_to_review_[1-5]:", + "security_concerns:", + "key_issues_to_review:", + "relevant_file:", + "relevant_line:", + "suggestion:", + ], + first_key=first_key, + last_key=last_key, + ) github_action_output(data, 'review') # move data['review'] 'key_issues_to_review' key to the end of the dictionary @@ -234,24 +297,38 @@ class PRReviewer: incremental_review_markdown_text = None # Add incremental review section if self.incremental.is_incremental: - last_commit_url = f"{self.git_provider.get_pr_url()}/commits/" \ - f"{self.git_provider.incremental.first_new_commit_sha}" + last_commit_url = ( + f"{self.git_provider.get_pr_url()}/commits/" + f"{self.git_provider.incremental.first_new_commit_sha}" + ) incremental_review_markdown_text = f"Starting from commit {last_commit_url}" - markdown_text = convert_to_markdown_v2(data, self.git_provider.is_supported("gfm_markdown"), - incremental_review_markdown_text, - git_provider=self.git_provider, - files=self.git_provider.get_diff_files()) + markdown_text = convert_to_markdown_v2( + data, + self.git_provider.is_supported("gfm_markdown"), + incremental_review_markdown_text, + git_provider=self.git_provider, + files=self.git_provider.get_diff_files(), + ) # Add help text if gfm_markdown is supported - if self.git_provider.is_supported("gfm_markdown") and get_settings().pr_reviewer.enable_help_text: + if ( + self.git_provider.is_supported("gfm_markdown") + and get_settings().pr_reviewer.enable_help_text + ): markdown_text += "
\n\n<details> <summary><strong>💡 Tool usage guide:</strong></summary><hr> \n\n"
markdown_text += HelpMessage.get_review_usage_guide()
markdown_text += "\n</details>
    \n" # Output the relevant configurations if enabled - if get_settings().get('config', {}).get('output_relevant_configurations', False): - markdown_text += show_relevant_configurations(relevant_section='pr_reviewer') + if ( + get_settings() + .get('config', {}) + .get('output_relevant_configurations', False) + ): + markdown_text += show_relevant_configurations( + relevant_section='pr_reviewer' + ) # Add custom labels from the review prediction (effort, security) self.set_review_labels(data) @@ -306,34 +383,50 @@ class PRReviewer: if comment: self.git_provider.remove_comment(comment) except Exception as e: - get_logger().exception(f"Failed to remove previous review comment, error: {e}") + get_logger().exception( + f"Failed to remove previous review comment, error: {e}" + ) def _can_run_incremental_review(self) -> bool: """Checks if we can run incremental review according the various configurations and previous review""" # checking if running is auto mode but there are no new commits if self.is_auto and not self.incremental.first_new_commit_sha: - get_logger().info(f"Incremental review is enabled for {self.pr_url} but there are no new commits") + get_logger().info( + f"Incremental review is enabled for {self.pr_url} but there are no new commits" + ) return False if not hasattr(self.git_provider, "get_incremental_commits"): - get_logger().info(f"Incremental review is not supported for {get_settings().config.git_provider}") + get_logger().info( + f"Incremental review is not supported for {get_settings().config.git_provider}" + ) return False # checking if there are enough commits to start the review num_new_commits = len(self.incremental.commits_range) - num_commits_threshold = get_settings().pr_reviewer.minimal_commits_for_incremental_review + num_commits_threshold = ( + get_settings().pr_reviewer.minimal_commits_for_incremental_review + ) not_enough_commits = num_new_commits < num_commits_threshold # checking if the commits are not too recent to start the review recent_commits_threshold = datetime.datetime.now() - datetime.timedelta( minutes=get_settings().pr_reviewer.minimal_minutes_for_incremental_review ) last_seen_commit_date = ( - self.incremental.last_seen_commit.commit.author.date if self.incremental.last_seen_commit else None + self.incremental.last_seen_commit.commit.author.date + if self.incremental.last_seen_commit + else None ) all_commits_too_recent = ( - last_seen_commit_date > recent_commits_threshold if self.incremental.last_seen_commit else False + last_seen_commit_date > recent_commits_threshold + if self.incremental.last_seen_commit + else False ) # check all the thresholds or just one to start the review - condition = any if get_settings().pr_reviewer.require_all_thresholds_for_incremental_review else all + condition = ( + any + if get_settings().pr_reviewer.require_all_thresholds_for_incremental_review + else all + ) if condition((not_enough_commits, all_commits_too_recent)): get_logger().info( f"Incremental review is enabled for {self.pr_url} but didn't pass the threshold check to run:" @@ -348,31 +441,55 @@ class PRReviewer: return if not get_settings().pr_reviewer.require_estimate_effort_to_review: - get_settings().pr_reviewer.enable_review_labels_effort = False # we did not generate this output + get_settings().pr_reviewer.enable_review_labels_effort = ( + False # we did not generate this output + ) if not get_settings().pr_reviewer.require_security_review: - get_settings().pr_reviewer.enable_review_labels_security = False # we did not generate this output + 
get_settings().pr_reviewer.enable_review_labels_security = ( + False # we did not generate this output + ) - if (get_settings().pr_reviewer.enable_review_labels_security or - get_settings().pr_reviewer.enable_review_labels_effort): + if ( + get_settings().pr_reviewer.enable_review_labels_security + or get_settings().pr_reviewer.enable_review_labels_effort + ): try: review_labels = [] if get_settings().pr_reviewer.enable_review_labels_effort: - estimated_effort = data['review']['estimated_effort_to_review_[1-5]'] + estimated_effort = data['review'][ + 'estimated_effort_to_review_[1-5]' + ] estimated_effort_number = 0 if isinstance(estimated_effort, str): try: - estimated_effort_number = int(estimated_effort.split(',')[0]) + estimated_effort_number = int( + estimated_effort.split(',')[0] + ) except ValueError: - get_logger().warning(f"Invalid estimated_effort value: {estimated_effort}") + get_logger().warning( + f"Invalid estimated_effort value: {estimated_effort}" + ) elif isinstance(estimated_effort, int): estimated_effort_number = estimated_effort else: - get_logger().warning(f"Unexpected type for estimated_effort: {type(estimated_effort)}") + get_logger().warning( + f"Unexpected type for estimated_effort: {type(estimated_effort)}" + ) if 1 <= estimated_effort_number <= 5: # 1, because ... - review_labels.append(f'Review effort {estimated_effort_number}/5') - if get_settings().pr_reviewer.enable_review_labels_security and get_settings().pr_reviewer.require_security_review: - security_concerns = data['review']['security_concerns'] # yes, because ... - security_concerns_bool = 'yes' in security_concerns.lower() or 'true' in security_concerns.lower() + review_labels.append( + f'Review effort {estimated_effort_number}/5' + ) + if ( + get_settings().pr_reviewer.enable_review_labels_security + and get_settings().pr_reviewer.require_security_review + ): + security_concerns = data['review'][ + 'security_concerns' + ] # yes, because ... 
+ security_concerns_bool = ( + 'yes' in security_concerns.lower() + or 'true' in security_concerns.lower() + ) if security_concerns_bool: review_labels.append('Possible security concern') @@ -381,17 +498,26 @@ class PRReviewer: current_labels = [] get_logger().debug(f"Current labels:\n{current_labels}") if current_labels: - current_labels_filtered = [label for label in current_labels if - not label.lower().startswith('review effort') and not label.lower().startswith( - 'possible security concern')] + current_labels_filtered = [ + label + for label in current_labels + if not label.lower().startswith('review effort') + and not label.lower().startswith('possible security concern') + ] else: current_labels_filtered = [] new_labels = review_labels + current_labels_filtered - if (current_labels or review_labels) and sorted(new_labels) != sorted(current_labels): - get_logger().info(f"Setting review labels:\n{review_labels + current_labels_filtered}") + if (current_labels or review_labels) and sorted(new_labels) != sorted( + current_labels + ): + get_logger().info( + f"Setting review labels:\n{review_labels + current_labels_filtered}" + ) self.git_provider.publish_labels(new_labels) else: - get_logger().info(f"Review labels are already set:\n{review_labels + current_labels_filtered}") + get_logger().info( + f"Review labels are already set:\n{review_labels + current_labels_filtered}" + ) except Exception as e: get_logger().error(f"Failed to set review labels, error: {e}") @@ -406,5 +532,7 @@ class PRReviewer: self.git_provider.publish_comment("自动批准 PR") else: get_logger().info("Auto-approval option is disabled") - self.git_provider.publish_comment("PR-Agent 的自动批准选项已禁用. " - "你可以通过此设置打开 [configuration file](https://github.com/Codium-ai/pr-agent/blob/main/docs/REVIEW.md#auto-approval-1)") + self.git_provider.publish_comment( + "PR-Agent 的自动批准选项已禁用. 
" + "你可以通过此设置打开 [configuration file](https://github.com/Codium-ai/pr-agent/blob/main/docs/REVIEW.md#auto-approval-1)" + ) diff --git a/apps/utils/pr_agent/tools/pr_similar_issue.py b/apps/utils/pr_agent/tools/pr_similar_issue.py index 6f9ea20..a8d2300 100644 --- a/apps/utils/pr_agent/tools/pr_similar_issue.py +++ b/apps/utils/pr_agent/tools/pr_similar_issue.py @@ -24,12 +24,16 @@ class PRSimilarIssue: self.max_issues_to_scan = get_settings().pr_similar_issue.max_issues_to_scan self.issue_url = issue_url self.git_provider = get_git_provider()() - repo_name, issue_number = self.git_provider._parse_issue_url(issue_url.split('=')[-1]) + repo_name, issue_number = self.git_provider._parse_issue_url( + issue_url.split('=')[-1] + ) self.git_provider.repo = repo_name self.git_provider.repo_obj = self.git_provider.github_client.get_repo(repo_name) self.token_handler = TokenHandler() repo_obj = self.git_provider.repo_obj - repo_name_for_index = self.repo_name_for_index = repo_obj.full_name.lower().replace('/', '-').replace('_/', '-') + repo_name_for_index = self.repo_name_for_index = ( + repo_obj.full_name.lower().replace('/', '-').replace('_/', '-') + ) index_name = self.index_name = "codium-ai-pr-agent-issues" if get_settings().pr_similar_issue.vectordb == "pinecone": @@ -38,17 +42,30 @@ class PRSimilarIssue: import pinecone from pinecone_datasets import Dataset, DatasetMetadata except: - raise Exception("Please install 'pinecone' and 'pinecone_datasets' to use pinecone as vectordb") + raise Exception( + "Please install 'pinecone' and 'pinecone_datasets' to use pinecone as vectordb" + ) # assuming pinecone api key and environment are set in secrets file try: api_key = get_settings().pinecone.api_key environment = get_settings().pinecone.environment except Exception: if not self.cli_mode: - repo_name, original_issue_number = self.git_provider._parse_issue_url(self.issue_url.split('=')[-1]) - issue_main = self.git_provider.repo_obj.get_issue(original_issue_number) - issue_main.create_comment("Please set pinecone api key and environment in secrets file") - raise Exception("Please set pinecone api key and environment in secrets file") + ( + repo_name, + original_issue_number, + ) = self.git_provider._parse_issue_url( + self.issue_url.split('=')[-1] + ) + issue_main = self.git_provider.repo_obj.get_issue( + original_issue_number + ) + issue_main.create_comment( + "Please set pinecone api key and environment in secrets file" + ) + raise Exception( + "Please set pinecone api key and environment in secrets file" + ) # check if index exists, and if repo is already indexed run_from_scratch = False @@ -69,7 +86,9 @@ class PRSimilarIssue: upsert = True else: pinecone_index = pinecone.Index(index_name=index_name) - res = pinecone_index.fetch([f"example_issue_{repo_name_for_index}"]).to_dict() + res = pinecone_index.fetch( + [f"example_issue_{repo_name_for_index}"] + ).to_dict() if res["vectors"]: upsert = False @@ -79,7 +98,9 @@ class PRSimilarIssue: get_logger().info('Getting issues...') issues = list(repo_obj.get_issues(state='all')) get_logger().info('Done') - self._update_index_with_issues(issues, repo_name_for_index, upsert=upsert) + self._update_index_with_issues( + issues, repo_name_for_index, upsert=upsert + ) else: # update index if needed pinecone_index = pinecone.Index(index_name=index_name) issues_to_update = [] @@ -105,7 +126,9 @@ class PRSimilarIssue: if issues_to_update: get_logger().info(f'Updating index with {counter} new issues...') - self._update_index_with_issues(issues_to_update, 
repo_name_for_index, upsert=True) + self._update_index_with_issues( + issues_to_update, repo_name_for_index, upsert=True + ) else: get_logger().info('No new issues to update') @@ -133,7 +156,12 @@ class PRSimilarIssue: ingest = True else: self.table = self.db[index_name] - res = self.table.search().limit(len(self.table)).where(f"id='example_issue_{repo_name_for_index}'").to_list() + res = ( + self.table.search() + .limit(len(self.table)) + .where(f"id='example_issue_{repo_name_for_index}'") + .to_list() + ) get_logger().info("result: ", res) if res[0].get("vector"): ingest = False @@ -145,7 +173,9 @@ class PRSimilarIssue: issues = list(repo_obj.get_issues(state='all')) get_logger().info('Done') - self._update_table_with_issues(issues, repo_name_for_index, ingest=ingest) + self._update_table_with_issues( + issues, repo_name_for_index, ingest=ingest + ) else: # update table if needed issues_to_update = [] issues_paginated_list = repo_obj.get_issues(state='all') @@ -156,7 +186,12 @@ class PRSimilarIssue: issue_str, comments, number = self._process_issue(issue) issue_key = f"issue_{number}" issue_id = issue_key + "." + "issue" - res = self.table.search().limit(len(self.table)).where(f"id='{issue_id}'").to_list() + res = ( + self.table.search() + .limit(len(self.table)) + .where(f"id='{issue_id}'") + .to_list() + ) is_new_issue = True for r in res: if r['metadata']['repo'] == repo_name_for_index: @@ -170,14 +205,17 @@ class PRSimilarIssue: if issues_to_update: get_logger().info(f'Updating index with {counter} new issues...') - self._update_table_with_issues(issues_to_update, repo_name_for_index, ingest=True) + self._update_table_with_issues( + issues_to_update, repo_name_for_index, ingest=True + ) else: get_logger().info('No new issues to update') - async def run(self): get_logger().info('Getting issue...') - repo_name, original_issue_number = self.git_provider._parse_issue_url(self.issue_url.split('=')[-1]) + repo_name, original_issue_number = self.git_provider._parse_issue_url( + self.issue_url.split('=')[-1] + ) issue_main = self.git_provider.repo_obj.get_issue(original_issue_number) issue_str, comments, number = self._process_issue(issue_main) openai.api_key = get_settings().openai.key @@ -193,10 +231,12 @@ class PRSimilarIssue: if get_settings().pr_similar_issue.vectordb == "pinecone": pinecone_index = pinecone.Index(index_name=self.index_name) - res = pinecone_index.query(embeds[0], - top_k=5, - filter={"repo": self.repo_name_for_index}, - include_metadata=True).to_dict() + res = pinecone_index.query( + embeds[0], + top_k=5, + filter={"repo": self.repo_name_for_index}, + include_metadata=True, + ).to_dict() for r in res['matches']: # skip example issue @@ -214,14 +254,20 @@ class PRSimilarIssue: if issue_number not in relevant_issues_number_list: relevant_issues_number_list.append(issue_number) if 'comment' in r["id"]: - relevant_comment_number_list.append(int(r["id"].split('.')[1].split('_')[-1])) + relevant_comment_number_list.append( + int(r["id"].split('.')[1].split('_')[-1]) + ) else: relevant_comment_number_list.append(-1) score_list.append(str("{:.2f}".format(r['score']))) get_logger().info('Done') elif get_settings().pr_similar_issue.vectordb == "lancedb": - res = self.table.search(embeds[0]).where(f"metadata.repo='{self.repo_name_for_index}'", prefilter=True).to_list() + res = ( + self.table.search(embeds[0]) + .where(f"metadata.repo='{self.repo_name_for_index}'", prefilter=True) + .to_list() + ) for r in res: # skip example issue @@ -240,10 +286,12 @@ class PRSimilarIssue: 
relevant_issues_number_list.append(issue_number) if 'comment' in r["id"]: - relevant_comment_number_list.append(int(r["id"].split('.')[1].split('_')[-1])) + relevant_comment_number_list.append( + int(r["id"].split('.')[1].split('_')[-1]) + ) else: relevant_comment_number_list.append(-1) - score_list.append(str("{:.2f}".format(1-r['_distance']))) + score_list.append(str("{:.2f}".format(1 - r['_distance']))) get_logger().info('Done') get_logger().info('Publishing response...') @@ -254,8 +302,12 @@ class PRSimilarIssue: title = issue.title url = issue.html_url if relevant_comment_number_list[i] != -1: - url = list(issue.get_comments())[relevant_comment_number_list[i]].html_url - similar_issues_str += f"{i + 1}. **[{title}]({url})** (score={score_list[i]})\n\n" + url = list(issue.get_comments())[ + relevant_comment_number_list[i] + ].html_url + similar_issues_str += ( + f"{i + 1}. **[{title}]({url})** (score={score_list[i]})\n\n" + ) if get_settings().config.publish_output: response = issue_main.create_comment(similar_issues_str) get_logger().info(similar_issues_str) @@ -278,7 +330,7 @@ class PRSimilarIssue: example_issue_record = Record( id=f"example_issue_{repo_name_for_index}", text="example_issue", - metadata=Metadata(repo=repo_name_for_index) + metadata=Metadata(repo=repo_name_for_index), ) corpus.append(example_issue_record) @@ -298,15 +350,20 @@ class PRSimilarIssue: issue_key = f"issue_{number}" username = issue.user.login created_at = str(issue.created_at) - if len(issue_str) < 8000 or \ - self.token_handler.count_tokens(issue_str) < get_max_tokens(MODEL): # fast reject first + if len(issue_str) < 8000 or self.token_handler.count_tokens( + issue_str + ) < get_max_tokens( + MODEL + ): # fast reject first issue_record = Record( id=issue_key + "." 
+ "issue", text=issue_str, - metadata=Metadata(repo=repo_name_for_index, - username=username, - created_at=created_at, - level=IssueLevel.ISSUE) + metadata=Metadata( + repo=repo_name_for_index, + username=username, + created_at=created_at, + level=IssueLevel.ISSUE, + ), ) corpus.append(issue_record) if comments: @@ -316,15 +373,20 @@ class PRSimilarIssue: if num_words_comment < 10 or not isinstance(comment_body, str): continue - if len(comment_body) < 8000 or \ - self.token_handler.count_tokens(comment_body) < MAX_TOKENS[MODEL]: + if ( + len(comment_body) < 8000 + or self.token_handler.count_tokens(comment_body) + < MAX_TOKENS[MODEL] + ): comment_record = Record( id=issue_key + ".comment_" + str(j + 1), text=comment_body, - metadata=Metadata(repo=repo_name_for_index, - username=username, # use issue username for all comments - created_at=created_at, - level=IssueLevel.COMMENT) + metadata=Metadata( + repo=repo_name_for_index, + username=username, # use issue username for all comments + created_at=created_at, + level=IssueLevel.COMMENT, + ), ) corpus.append(comment_record) df = pd.DataFrame(corpus.dict()["documents"]) @@ -355,7 +417,9 @@ class PRSimilarIssue: environment = get_settings().pinecone.environment if not upsert: get_logger().info('Creating index from scratch...') - ds.to_pinecone_index(self.index_name, api_key=api_key, environment=environment) + ds.to_pinecone_index( + self.index_name, api_key=api_key, environment=environment + ) time.sleep(15) # wait for pinecone to finalize indexing before querying else: get_logger().info('Upserting index...') @@ -374,7 +438,7 @@ class PRSimilarIssue: example_issue_record = Record( id=f"example_issue_{repo_name_for_index}", text="example_issue", - metadata=Metadata(repo=repo_name_for_index) + metadata=Metadata(repo=repo_name_for_index), ) corpus.append(example_issue_record) @@ -394,15 +458,20 @@ class PRSimilarIssue: issue_key = f"issue_{number}" username = issue.user.login created_at = str(issue.created_at) - if len(issue_str) < 8000 or \ - self.token_handler.count_tokens(issue_str) < get_max_tokens(MODEL): # fast reject first + if len(issue_str) < 8000 or self.token_handler.count_tokens( + issue_str + ) < get_max_tokens( + MODEL + ): # fast reject first issue_record = Record( id=issue_key + "." 
+ "issue", text=issue_str, - metadata=Metadata(repo=repo_name_for_index, - username=username, - created_at=created_at, - level=IssueLevel.ISSUE) + metadata=Metadata( + repo=repo_name_for_index, + username=username, + created_at=created_at, + level=IssueLevel.ISSUE, + ), ) corpus.append(issue_record) if comments: @@ -412,15 +481,20 @@ class PRSimilarIssue: if num_words_comment < 10 or not isinstance(comment_body, str): continue - if len(comment_body) < 8000 or \ - self.token_handler.count_tokens(comment_body) < MAX_TOKENS[MODEL]: + if ( + len(comment_body) < 8000 + or self.token_handler.count_tokens(comment_body) + < MAX_TOKENS[MODEL] + ): comment_record = Record( id=issue_key + ".comment_" + str(j + 1), text=comment_body, - metadata=Metadata(repo=repo_name_for_index, - username=username, # use issue username for all comments - created_at=created_at, - level=IssueLevel.COMMENT) + metadata=Metadata( + repo=repo_name_for_index, + username=username, # use issue username for all comments + created_at=created_at, + level=IssueLevel.COMMENT, + ), ) corpus.append(comment_record) df = pd.DataFrame(corpus.dict()["documents"]) @@ -446,7 +520,9 @@ class PRSimilarIssue: if not ingest: get_logger().info('Creating table from scratch...') - self.table = self.db.create_table(self.index_name, data=df, mode="overwrite") + self.table = self.db.create_table( + self.index_name, data=df, mode="overwrite" + ) time.sleep(15) else: get_logger().info('Ingesting in Table...') diff --git a/apps/utils/pr_agent/tools/pr_update_changelog.py b/apps/utils/pr_agent/tools/pr_update_changelog.py index 56c9eca..f4a7b24 100644 --- a/apps/utils/pr_agent/tools/pr_update_changelog.py +++ b/apps/utils/pr_agent/tools/pr_update_changelog.py @@ -20,13 +20,20 @@ CHANGELOG_LINES = 50 class PRUpdateChangelog: - def __init__(self, pr_url: str, cli_mode=False, args=None, ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler): - + def __init__( + self, + pr_url: str, + cli_mode=False, + args=None, + ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler, + ): self.git_provider = get_git_provider()(pr_url) self.main_language = get_main_pr_language( self.git_provider.get_languages(), self.git_provider.get_files() ) - self.commit_changelog = get_settings().pr_update_changelog.push_changelog_changes + self.commit_changelog = ( + get_settings().pr_update_changelog.push_changelog_changes + ) self._get_changelog_file() # self.changelog_file_str self.ai_handler = ai_handler() @@ -47,15 +54,19 @@ class PRUpdateChangelog: "extra_instructions": get_settings().pr_update_changelog.extra_instructions, "commit_messages_str": self.git_provider.get_commit_messages(), } - self.token_handler = TokenHandler(self.git_provider.pr, - self.vars, - get_settings().pr_update_changelog_prompt.system, - get_settings().pr_update_changelog_prompt.user) + self.token_handler = TokenHandler( + self.git_provider.pr, + self.vars, + get_settings().pr_update_changelog_prompt.system, + get_settings().pr_update_changelog_prompt.user, + ) async def run(self): get_logger().info('Updating the changelog...') - relevant_configs = {'pr_update_changelog': dict(get_settings().pr_update_changelog), - 'config': dict(get_settings().config)} + relevant_configs = { + 'pr_update_changelog': dict(get_settings().pr_update_changelog), + 'config': dict(get_settings().config), + } get_logger().debug("Relevant configs", artifacts=relevant_configs) # currently only GitHub is supported for pushing changelog changes @@ -74,13 +85,21 @@ class PRUpdateChangelog: if get_settings().config.publish_output: 
self.git_provider.publish_comment("准备变更日志更新中...", is_temporary=True) - await retry_with_fallback_models(self._prepare_prediction, model_type=ModelType.WEAK) + await retry_with_fallback_models( + self._prepare_prediction, model_type=ModelType.WEAK + ) new_file_content, answer = self._prepare_changelog_update() # Output the relevant configurations if enabled - if get_settings().get('config', {}).get('output_relevant_configurations', False): - answer += show_relevant_configurations(relevant_section='pr_update_changelog') + if ( + get_settings() + .get('config', {}) + .get('output_relevant_configurations', False) + ): + answer += show_relevant_configurations( + relevant_section='pr_update_changelog' + ) get_logger().debug(f"PR output", artifact=answer) @@ -89,7 +108,9 @@ class PRUpdateChangelog: if self.commit_changelog: self._push_changelog_update(new_file_content, answer) else: - self.git_provider.publish_comment(f"**Changelog updates:** 🔄\n\n{answer}") + self.git_provider.publish_comment( + f"**Changelog updates:** 🔄\n\n{answer}" + ) async def _prepare_prediction(self, model: str): self.patches_diff = get_pr_diff(self.git_provider, self.token_handler, model) @@ -106,10 +127,18 @@ class PRUpdateChangelog: if get_settings().pr_update_changelog.add_pr_link: variables["pr_link"] = self.git_provider.get_pr_url() environment = Environment(undefined=StrictUndefined) - system_prompt = environment.from_string(get_settings().pr_update_changelog_prompt.system).render(variables) - user_prompt = environment.from_string(get_settings().pr_update_changelog_prompt.user).render(variables) + system_prompt = environment.from_string( + get_settings().pr_update_changelog_prompt.system + ).render(variables) + user_prompt = environment.from_string( + get_settings().pr_update_changelog_prompt.user + ).render(variables) response, finish_reason = await self.ai_handler.chat_completion( - model=model, system=system_prompt, user=user_prompt, temperature=get_settings().config.temperature) + model=model, + system=system_prompt, + user=user_prompt, + temperature=get_settings().config.temperature, + ) # post-process the response response = response.strip() @@ -134,8 +163,10 @@ class PRUpdateChangelog: new_file_content = answer if not self.commit_changelog: - answer += "\n\n\n>to commit the new content to the CHANGELOG.md file, please type:" \ - "\n>'/update_changelog --pr_update_changelog.push_changelog_changes=true'\n" + answer += ( + "\n\n\n>to commit the new content to the CHANGELOG.md file, please type:" + "\n>'/update_changelog --pr_update_changelog.push_changelog_changes=true'\n" + ) return new_file_content, answer @@ -163,8 +194,7 @@ class PRUpdateChangelog: self.git_provider.publish_comment(f"**Changelog updates: 🔄**\n\n{answer}") def _get_default_changelog(self): - example_changelog = \ -""" + example_changelog = """ Example: ## diff --git a/apps/utils/pr_agent/tools/ticket_pr_compliance_check.py b/apps/utils/pr_agent/tools/ticket_pr_compliance_check.py index 387f428..efc7d3a 100644 --- a/apps/utils/pr_agent/tools/ticket_pr_compliance_check.py +++ b/apps/utils/pr_agent/tools/ticket_pr_compliance_check.py @@ -7,14 +7,15 @@ from utils.pr_agent.log import get_logger # Compile the regex pattern once, outside the function GITHUB_TICKET_PATTERN = re.compile( - r'(https://github[^/]+/[^/]+/[^/]+/issues/\d+)|(\b(\w+)/(\w+)#(\d+)\b)|(#\d+)' + r'(https://github[^/]+/[^/]+/[^/]+/issues/\d+)|(\b(\w+)/(\w+)#(\d+)\b)|(#\d+)' ) + def find_jira_tickets(text): # Regular expression patterns for JIRA tickets patterns = [ 
r'\b[A-Z]{2,10}-\d{1,7}\b', # Standard JIRA ticket format (e.g., PROJ-123) - r'(?:https?://[^\s/]+/browse/)?([A-Z]{2,10}-\d{1,7})\b' # JIRA URL or just the ticket + r'(?:https?://[^\s/]+/browse/)?([A-Z]{2,10}-\d{1,7})\b', # JIRA URL or just the ticket ] tickets = set() @@ -32,7 +33,9 @@ def find_jira_tickets(text): return list(tickets) -def extract_ticket_links_from_pr_description(pr_description, repo_path, base_url_html='https://github.com'): +def extract_ticket_links_from_pr_description( + pr_description, repo_path, base_url_html='https://github.com' +): """ Extract all ticket links from PR description """ @@ -46,19 +49,27 @@ def extract_ticket_links_from_pr_description(pr_description, repo_path, base_url github_tickets.add(match[0]) elif match[1]: # Shorthand notation match: owner/repo#issue_number owner, repo, issue_number = match[2], match[3], match[4] - github_tickets.add(f'{base_url_html.strip("/")}/{owner}/{repo}/issues/{issue_number}') + github_tickets.add( + f'{base_url_html.strip("/")}/{owner}/{repo}/issues/{issue_number}' + ) else: # #123 format issue_number = match[5][1:] # remove # if issue_number.isdigit() and len(issue_number) < 5 and repo_path: - github_tickets.add(f'{base_url_html.strip("/")}/{repo_path}/issues/{issue_number}') + github_tickets.add( + f'{base_url_html.strip("/")}/{repo_path}/issues/{issue_number}' + ) if len(github_tickets) > 3: - get_logger().info(f"Too many tickets found in PR description: {len(github_tickets)}") + get_logger().info( + f"Too many tickets found in PR description: {len(github_tickets)}" + ) # Limit the number of tickets to 3 github_tickets = set(list(github_tickets)[:3]) except Exception as e: - get_logger().error(f"Error extracting tickets error= {e}", - artifact={"traceback": traceback.format_exc()}) + get_logger().error( + f"Error extracting tickets error= {e}", + artifact={"traceback": traceback.format_exc()}, + ) return list(github_tickets) @@ -68,19 +79,26 @@ async def extract_tickets(git_provider): try: if isinstance(git_provider, GithubProvider): user_description = git_provider.get_user_description() - tickets = extract_ticket_links_from_pr_description(user_description, git_provider.repo, git_provider.base_url_html) + tickets = extract_ticket_links_from_pr_description( + user_description, git_provider.repo, git_provider.base_url_html + ) tickets_content = [] if tickets: - for ticket in tickets: - repo_name, original_issue_number = git_provider._parse_issue_url(ticket) + repo_name, original_issue_number = git_provider._parse_issue_url( + ticket + ) try: - issue_main = git_provider.repo_obj.get_issue(original_issue_number) + issue_main = git_provider.repo_obj.get_issue( + original_issue_number + ) except Exception as e: - get_logger().error(f"Error getting main issue: {e}", - artifact={"traceback": traceback.format_exc()}) + get_logger().error( + f"Error getting main issue: {e}", + artifact={"traceback": traceback.format_exc()}, + ) continue issue_body_str = issue_main.body or "" @@ -93,47 +111,66 @@ async def extract_tickets(git_provider): sub_issues = git_provider.fetch_sub_issues(ticket) for sub_issue_url in sub_issues: try: - sub_repo, sub_issue_number = git_provider._parse_issue_url(sub_issue_url) - sub_issue = git_provider.repo_obj.get_issue(sub_issue_number) + ( + sub_repo, + sub_issue_number, + ) = git_provider._parse_issue_url(sub_issue_url) + sub_issue = git_provider.repo_obj.get_issue( + sub_issue_number + ) sub_body = sub_issue.body or "" if len(sub_body) > MAX_TICKET_CHARACTERS: sub_body = 
sub_body[:MAX_TICKET_CHARACTERS] + "..." - sub_issues_content.append({ - 'ticket_url': sub_issue_url, - 'title': sub_issue.title, - 'body': sub_body - }) + sub_issues_content.append( + { + 'ticket_url': sub_issue_url, + 'title': sub_issue.title, + 'body': sub_body, + } + ) except Exception as e: - get_logger().warning(f"Failed to fetch sub-issue content for {sub_issue_url}: {e}") + get_logger().warning( + f"Failed to fetch sub-issue content for {sub_issue_url}: {e}" + ) except Exception as e: - get_logger().warning(f"Failed to fetch sub-issues for {ticket}: {e}") + get_logger().warning( + f"Failed to fetch sub-issues for {ticket}: {e}" + ) # Extract labels labels = [] try: for label in issue_main.labels: - labels.append(label.name if hasattr(label, 'name') else label) + labels.append( + label.name if hasattr(label, 'name') else label + ) except Exception as e: - get_logger().error(f"Error extracting labels error= {e}", - artifact={"traceback": traceback.format_exc()}) + get_logger().error( + f"Error extracting labels error= {e}", + artifact={"traceback": traceback.format_exc()}, + ) - tickets_content.append({ - 'ticket_id': issue_main.number, - 'ticket_url': ticket, - 'title': issue_main.title, - 'body': issue_body_str, - 'labels': ", ".join(labels), - 'sub_issues': sub_issues_content # Store sub-issues content - }) + tickets_content.append( + { + 'ticket_id': issue_main.number, + 'ticket_url': ticket, + 'title': issue_main.title, + 'body': issue_body_str, + 'labels': ", ".join(labels), + 'sub_issues': sub_issues_content, # Store sub-issues content + } + ) return tickets_content except Exception as e: - get_logger().error(f"Error extracting tickets error= {e}", - artifact={"traceback": traceback.format_exc()}) + get_logger().error( + f"Error extracting tickets error= {e}", + artifact={"traceback": traceback.format_exc()}, + ) async def extract_and_cache_pr_tickets(git_provider, vars): @@ -154,8 +191,10 @@ async def extract_and_cache_pr_tickets(git_provider, vars): related_tickets.append(ticket) - get_logger().info("Extracted tickets and sub-issues from PR description", - artifact={"tickets": related_tickets}) + get_logger().info( + "Extracted tickets and sub-issues from PR description", + artifact={"tickets": related_tickets}, + ) vars['related_tickets'] = related_tickets get_settings().set('related_tickets', related_tickets) diff --git a/config.ini b/config.ini new file mode 100644 index 0000000..4c277de --- /dev/null +++ b/config.ini @@ -0,0 +1,13 @@ +[BASE] +; 是否开启debug模式 0或1 +DEBUG = 0 + +[DATABASE] +; 默认采用sqlite,线上需替换为pg +DEFAULT = pg +; postgres配置 +DB_NAME = pr_manager +DB_USER = admin +DB_PASSWORD = admin123456 +DB_HOST = 110.40.30.95 +DB_PORT = 5432 diff --git a/pr_manager/settings.py b/pr_manager/settings.py index d2a19e0..a4423ff 100644 --- a/pr_manager/settings.py +++ b/pr_manager/settings.py @@ -12,11 +12,18 @@ https://docs.djangoproject.com/en/5.1/ref/settings/ import os import sys +import configparser from pathlib import Path # Build paths inside the project like this: BASE_DIR / 'subdir'. 
BASE_DIR = Path(__file__).resolve().parent.parent +CONFIG_NAME = BASE_DIR / "config.ini" + +# 加载配置文件: 开发可加载config.local.ini +_config = configparser.ConfigParser() +_config.read(CONFIG_NAME, encoding="utf-8") + sys.path.insert(0, os.path.join(BASE_DIR, "apps")) sys.path.insert(1, os.path.join(BASE_DIR, "apps/utils")) @@ -27,7 +34,7 @@ sys.path.insert(1, os.path.join(BASE_DIR, "apps/utils")) SECRET_KEY = "django-insecure-$r6lfcq8rev&&=chw259o$0o7t-!!%clc2ahs3xg$^z+gkms76" # SECURITY WARNING: don't run with debug turned on in production! -DEBUG = False +DEBUG = bool(int(_config["BASE"].get("DEBUG", "1"))) ALLOWED_HOSTS = ["*"] @@ -44,7 +51,7 @@ INSTALLED_APPS = [ "django.contrib.messages", "django.contrib.staticfiles", "public", - "pr" + "pr", ] # 配置安全秘钥 @@ -68,8 +75,7 @@ ROOT_URLCONF = "pr_manager.urls" TEMPLATES = [ { "BACKEND": "django.template.backends.django.DjangoTemplates", - "DIRS": [BASE_DIR / 'templates'] - , + "DIRS": [BASE_DIR / 'templates'], "APP_DIRS": True, "OPTIONS": { "context_processors": [ @@ -89,12 +95,22 @@ WSGI_APPLICATION = "pr_manager.wsgi.application" # https://docs.djangoproject.com/en/5.1/ref/settings/#databases DATABASES = { - "default": { + "pg": { + "ENGINE": "django.db.backends.postgresql", + "NAME": _config["DATABASE"].get("DB_NAME", "chat_ai_v2"), + "USER": _config["DATABASE"].get("DB_USER", "admin"), + "PASSWORD": _config["DATABASE"].get("DB_PASSWORD", "admin123456"), + "HOST": _config["DATABASE"].get("DB_HOST", "124.222.222.101"), + "PORT": int(_config["DATABASE"].get("DB_PORT", "5432")), + }, + "sqlite": { "ENGINE": "django.db.backends.sqlite3", "NAME": BASE_DIR / "db.sqlite3", - } + }, } +DATABASES["default"] = DATABASES[_config["DATABASE"].get("DEFAULT", "sqlite")] + # Password validation # https://docs.djangoproject.com/en/5.1/ref/settings/#auth-password-validators
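Note on the configuration wiring added in config.ini and pr_manager/settings.py above: the sketch below is illustrative only and not part of the patch. It replays the same configparser logic against an in-memory ini string (so no config.ini on disk is needed), and the database entries are trimmed down to their ENGINE keys.

# Minimal sketch, assuming only the [BASE]/[DATABASE] layout shown in config.ini above.
import configparser

_config = configparser.ConfigParser()
_config.read_string(
    "[BASE]\n"
    "DEBUG = 0\n"
    "[DATABASE]\n"
    "DEFAULT = sqlite\n"
)

# DEBUG is stored as "0"/"1" in the ini file and coerced to a bool, as in settings.py.
DEBUG = bool(int(_config["BASE"].get("DEBUG", "1")))

DATABASES = {
    "pg": {"ENGINE": "django.db.backends.postgresql"},
    "sqlite": {"ENGINE": "django.db.backends.sqlite3"},
}
# The alias named by [DATABASE] DEFAULT becomes Django's "default" connection.
DATABASES["default"] = DATABASES[_config["DATABASE"].get("DEFAULT", "sqlite")]

assert DEBUG is False
assert DATABASES["default"]["ENGINE"].endswith("sqlite3")

Switching DEFAULT to pg in config.ini would point "default" at the postgres entry instead, and DEBUG falls back to enabled ("1") when the key is missing from [BASE].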
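Note on the ticket detection in ticket_pr_compliance_check.py above: a small self-contained check of what the compiled GitHub pattern and the standard JIRA key format pick up. Illustrative only; the sample description, repository names, and ticket ids are invented, and only the regexes themselves are taken from the patch.

import re

# Copied from the patch: GitHub issue URLs, owner/repo#n shorthand, and bare #n references.
GITHUB_TICKET_PATTERN = re.compile(
    r'(https://github[^/]+/[^/]+/[^/]+/issues/\d+)|(\b(\w+)/(\w+)#(\d+)\b)|(#\d+)'
)
# First of the two JIRA patterns used by find_jira_tickets (PROJ-123 style keys).
JIRA_PATTERN = re.compile(r'\b[A-Z]{2,10}-\d{1,7}\b')

description = (
    "Fixes https://github.com/acme/widgets/issues/12, "
    "relates to acme/tools#34 and #56, JIRA: PROJ-123"
)

for match in GITHUB_TICKET_PATTERN.findall(description):
    # Each tuple has exactly one populated alternative (plus its capture groups):
    # a full issue URL, an owner/repo#number shorthand, or a bare #number.
    print([group for group in match if group])

print(JIRA_PATTERN.findall(description))  # ['PROJ-123']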
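Note on PRReviewer._can_run_incremental_review above: the incremental review is blocked based on two thresholds (too few new commits, commits too recent), and require_all_thresholds_for_incremental_review selects whether tripping any one of them is enough or all of them must trip. A tiny sketch with hard-coded inputs, reflecting how I read the diff (illustrative only):

# Hypothetical threshold results; in the real method these come from settings and commit data.
not_enough_commits = True        # fewer new commits than the configured minimum
all_commits_too_recent = False   # commits are old enough to review

for require_all in (True, False):
    # require_all=True maps to any(): one tripped threshold already blocks the incremental review.
    # require_all=False maps to all(): every threshold must trip before it is blocked.
    condition = any if require_all else all
    skip_review = condition((not_enough_commits, all_commits_too_recent))
    print(f"require_all_thresholds={require_all} -> skip incremental review: {skip_review}")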