Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 56 additions & 28 deletions pr_chat/chat_ws_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,62 +6,90 @@
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from contextlib import asynccontextmanager
from langchain_community.chat_message_histories import DynamoDBChatMessageHistory
from langchain_anthropic import ChatAnthropic
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableWithMessageHistory
from langchain.memory.chat_message_histories import RedisChatMessageHistory
import traceback

class PatchedDynamoDBChatMessageHistory(DynamoDBChatMessageHistory):
    """DynamoDB chat history whose ``key`` is derived from ``session_id``."""

    # NOTE(review): ``key`` is exposed as a read-only property here; if the
    # base class assigns ``self.key`` during construction this override would
    # reject that assignment -- confirm against the installed
    # langchain_community version.
    @property
    def key(self):
        """Return the item key mapping ``{"id": <session_id>}``."""
        session_key = {"id": self.session_id}
        return session_key


# Application state. ``llm`` and ``ANTHROPIC_API_KEY`` are populated by
# ``lifespan`` at startup; ``_chat_chain`` is built on first use after that.
llm = None
ANTHROPIC_API_KEY = None
_chat_chain = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Initialise the Claude client before the app starts serving.

    Reads the Anthropic API key from SSM Parameter Store, exports it (and the
    AWS region) through ``os.environ``, and creates the module-level ``llm``.
    Nothing is torn down after the ``yield``.
    """
    global ANTHROPIC_API_KEY, llm
    ssm = boto3.client('ssm', region_name='us-east-1')
    ANTHROPIC_API_KEY = ssm.get_parameter(
        Name="/prreview/ANTHROPIC_API_KEY",
        WithDecryption=True
    )['Parameter']['Value']
    os.environ["ANTHROPIC_API_KEY"] = ANTHROPIC_API_KEY
    os.environ["AWS_REGION"] = "us-east-1"
    llm = ChatAnthropic(model="claude-3-7-sonnet-20250219", temperature=0.7)

    yield  # app is now running


# Create the app exactly once, wired to the lifespan handler above.
app = FastAPI(lifespan=lifespan)
session = boto3.Session(region_name="us-east-1")

# Prompt. The "placeholder" entry is where RunnableWithMessageHistory injects
# the stored conversation (history_messages_key="messages"); without it the
# history is loaded but never shown to the model.
prompt = ChatPromptTemplate.from_messages([
    ("system", "You're a helpful assistant."),
    ("placeholder", "{messages}"),
    ("human", "{input}")
])


# DDB for chat history
def get_history(session_id: str):
    """Return the DynamoDB-backed message history for ``session_id``.

    Items live in the ``ChatMemory`` table under the key
    ``{"id": session_id}``. The module-level boto3 ``session`` is passed so
    the table is looked up in the same region the rest of the app uses.
    """
    print(f"[DEBUG] Using session_id={session_id} (type: {type(session_id)})")
    return DynamoDBChatMessageHistory(
        table_name="ChatMemory",
        session_id=session_id,
        key={"id": session_id},
        boto3_session=session,
    )


def get_chat_chain():
    """Return the shared prompt -> LLM -> string chain with message history.

    The chain is built on first use and cached, because ``llm`` only exists
    after the lifespan startup has run.

    Raises:
        RuntimeError: If called before ``lifespan`` has created ``llm``.
    """
    global _chat_chain
    if llm is None:
        raise RuntimeError("LLM not initialized yet")
    if _chat_chain is None:
        _chat_chain = RunnableWithMessageHistory(
            prompt | llm | StrOutputParser(),
            get_session_history=get_history,
            input_messages_key="input",
            history_messages_key="messages"
        )
    return _chat_chain


# WebSocket chat endpoint
@app.websocket("/ws/chat/{user_id}/{session_id}")
async def websocket_chat(websocket: WebSocket, user_id: str, session_id: str):
    """Relay chat messages between a WebSocket client and Claude.

    Each received text frame is sent through the history-aware chain for
    ``session_id`` and the model's reply is sent back as a text frame. A
    failure while handling one message is reported to the client as a
    ``[Server Error]`` frame and the loop continues.
    """
    await websocket.accept()
    try:
        if ANTHROPIC_API_KEY is None:
            print("🔑 ANTHROPIC_API_KEY not set. Cannot connect to Claude.")
        else:
            print(f"🔑 ANTHROPIC_API_KEY length is: {len(ANTHROPIC_API_KEY)}. Connected to Claude.")
        while True:
            message = await websocket.receive_text()
            print(f"Received message from {user_id}/{session_id}: {message}")

            try:
                # ainvoke keeps the event loop free while the model call and
                # the history reads/writes are in flight.
                response = await get_chat_chain().ainvoke(
                    {"input": message},
                    config={"configurable": {"session_id": session_id}}
                )
                await websocket.send_text(response)
            except WebSocketDisconnect:
                # Let the outer handler log the disconnect instead of trying
                # to send an error frame to a closed socket.
                raise
            except Exception as e:
                print("🔥 LLM invocation failed:")
                traceback.print_exc()
                await websocket.send_text(f"[Server Error] {str(e)}")

    except WebSocketDisconnect:
        print(f"Client disconnected: {user_id}/{session_id}")
Expand Down
13 changes: 1 addition & 12 deletions pr_chat/docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,10 @@ services:
build:
context: .
dockerfile: Dockerfile
container_name: chat-api
container_name: chat-ws-api
ports:
- "8080:8080"
environment:
REDIS_HOST: redis
REDIS_PORT: 6379
AWS_REGION: us-east-1
volumes:
- ~/.aws:/root/.aws:ro
depends_on:
- redis

redis:
image: redis:7
container_name: chat-redis
ports:
- "6379:6379"
restart: always
252 changes: 252 additions & 0 deletions pr_chat/git_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,252 @@
import os
import boto3
import json
import time
import requests
from github import Github
from github import Auth
from datetime import datetime, timezone
from dateutil import tz
from zoneinfo import ZoneInfo
from pathlib import Path

# Default repository and pull request used by the module-level URLs below.
owner = "public-apis"
repo = "public-apis"
pull_number = 25653  # Replace with the desired PR number
url = f"https://api.github.com/repos/{owner}/{repo}/pulls/{pull_number}"
pulls_url = f"https://api.github.com/repos/{owner}/{repo}/pulls"
# Example for a different repo:
# "https://api.github.com/repos/vinta/awesome-python/pulls"

# File extensions whose diffs get_supported_diffs() keeps.
SUPPORTED_EXTENSIONS = {'.py', '.js', '.ts', '.java', '.cs', '.cpp', '.c', '.go', '.rb'}

# Create an SSM client
ssm = boto3.client('ssm')
def get_parameter(name):
    """Return the value of the Parameter Store entry called ``name``.

    The lookup is made with ``WithDecryption=True`` through the module-level
    ``ssm`` client.
    """
    result = ssm.get_parameter(Name=name, WithDecryption=True)
    parameter = result['Parameter']
    return parameter['Value']

# Load secrets from AWS at cold start
GIT_API_KEY = get_parameter("/prreview/GIT_API_KEY")
# Treat an empty value the same as a missing one: every authenticated request
# below would otherwise be sent with a blank token.
if not GIT_API_KEY:
    raise ValueError("GIT_API_KEY was not found in the parameter store.")

# Default headers for the unauthenticated helpers (get_pr_details,
# get_pull_requests).
headers = {
    "Accept": "application/vnd.github.v3+json",
    # Optional: Add token for higher rate limits
    # "Authorization": "Bearer YOUR_TOKEN"
}

def get_pr_details():
    """Print title, branches and diff URL of the module-level default PR.

    Uses the module-level ``url`` and ``headers``. On a non-200 response the
    status code and the API's ``message`` field are printed instead.
    """
    reply = requests.get(url, headers=headers)
    if reply.status_code != 200:
        print(f"Error: {reply.status_code}, {reply.json().get('message')}")
        return

    details = reply.json()
    print("PR Title:", details["title"])
    print("Source Branch:", details["head"]["ref"])
    print("Target Branch:", details["base"]["ref"])
    print("Diff URL:", details["diff_url"])

def get_pr_diff(repo, pr_number):
    """Fetch the diff of a pull request as text.

    Args:
        repo: Repository in ``owner/name`` form.
        pr_number: Number of the pull request.

    Returns:
        The diff text when the request succeeds with HTTP 200, otherwise
        ``None`` (the status code is printed).
    """
    # Headers for diff request. The token goes on this request's own headers
    # rather than on the shared module-level ``headers`` dict.
    diff_headers = {
        "Accept": "application/vnd.github.v3.diff",
    }
    if GIT_API_KEY:
        diff_headers["Authorization"] = f"token {GIT_API_KEY}"

    # Ask the GitHub API for the PR with the diff media type so the token is
    # only ever sent to api.github.com.
    url = f"https://api.github.com/repos/{repo}/pulls/{pr_number}"
    response = requests.get(url, headers=diff_headers)
    if response.status_code == 200:
        return response.text

    print(f"Error: {response.status_code}")
    return None

def get_supported_diffs(repo, pr_number):
    """Return the changed files of a PR that have a supported extension.

    Requests the PR's file list from the GitHub API and keeps only entries
    whose filename extension is in ``SUPPORTED_EXTENSIONS`` and that carry a
    ``patch`` field. When any entries are dropped, a summary line is printed.

    Raises:
        requests.HTTPError: If the API answers with an error status.
    """
    endpoint = f"https://api.github.com/repos/{repo}/pulls/{pr_number}/files"
    auth_headers = {
        "Authorization": f"token {GIT_API_KEY}",
        "Accept": "application/vnd.github.v3+json"
    }
    reply = requests.get(endpoint, headers=auth_headers)
    reply.raise_for_status()
    changed_files = reply.json()

    # Keep only files with a supported extension that also include a patch.
    kept = []
    for entry in changed_files:
        _, extension = os.path.splitext(entry['filename'])
        if extension in SUPPORTED_EXTENSIONS and 'patch' in entry:
            kept.append(entry)

    skipped = len(changed_files) - len(kept)
    if skipped:
        print(f"Found {skipped} out of {len(changed_files)} filetypes in diffs which are not supported for PR #{pr_number} in {repo}.")

    return kept

def get_pr_files(repo, pr_number):
    """Fetch the files changed by a PR, print one line per file, return them.

    Each printed line has the form ``STATUS: filename`` with the status
    upper-cased.

    Raises:
        requests.HTTPError: If the API answers with an error status.
    """
    endpoint = f"https://api.github.com/repos/{repo}/pulls/{pr_number}/files"
    reply = requests.get(
        endpoint,
        headers={
            "Authorization": f"Bearer {GIT_API_KEY}",
            "Accept": "application/vnd.github.v3+json",
        },
    )
    reply.raise_for_status()  # Raise an error for bad responses

    changed = reply.json()
    for entry in changed:
        name = entry.get("filename")
        change_kind = entry.get("status")  # e.g. 'added', 'modified', 'removed'
        print(f"{change_kind.upper()}: {name}")

    return changed


def get_pull_requests(state='open'):
    """
    Fetch the first page of pull requests for the default repository.

    The repository is the one baked into the module-level ``pulls_url``, and
    the request is sent with the module-level ``headers``. Only one page of
    five results is requested; no further pages are followed.

    Args:
        state (str): State of PRs to fetch ('open', 'closed', or 'all')

    Returns:
        list: The pull requests from the first page, or an empty list when
        the API does not answer with HTTP 200 (the error is printed).
    """
    params = {
        "state": state,
        "per_page": 5,  # Page size requested from the API
        "page": 1,
    }

    print(f"Fetching {state} pull requests from {pulls_url}...")
    response = requests.get(pulls_url, headers=headers, params=params)

    if response.status_code != 200:
        print(f"Error: {response.status_code} - {response.json().get('message', 'Unknown error')}")
        # Keep the documented list return type so callers can still use
        # len() or iterate over the result.
        return []

    return list(response.json())

def print_pull_requests(prs):
    """
    Print basic information about pull requests.

    For every pull request this prints its state, title, number, author,
    creation time converted to the system's local timezone, and URL, followed
    by a separator line.

    Args:
        prs (list): List of pull request dicts with the keys 'state',
            'title', 'number' and 'html_url'; 'user' and 'created_at' may be
            missing or empty.
    """
    for pr in prs:
        print(f"State: {pr['state'].capitalize()}")
        print(f"{pr['title']}")
        # 'user' can be absent or null; fall back instead of crashing.
        user_name = (pr.get('user') or {}).get('login', 'Unknown User')
        created_at = pr.get('created_at')
        if created_at:
            # The trailing "Z" means UTC. Mark the parsed value as UTC before
            # converting, otherwise astimezone() would treat the naive value
            # as already being local time and leave it unchanged.
            utc_time = datetime.strptime(
                created_at, "%Y-%m-%dT%H:%M:%SZ"
            ).replace(tzinfo=timezone.utc)
            created_at = utc_time.astimezone().strftime("%Y-%m-%d at %H:%M:%S")
        else:
            created_at = "Unknown Date"
        print(f"PR #{pr['number']} opened by {user_name} on {created_at}")

        print(f"URL: {pr['html_url']}")
        print("-" * 50)

def post_review(repo, pr_number, decision, review, comments=None):
    """Submit a review on a pull request.

    Args:
        repo: Repository in ``owner/name`` form.
        pr_number: Number of the pull request.
        decision: Review event sent as the ``event`` field
            (e.g. "REQUEST_CHANGES").
        review: Text sent as the review ``body``.
        comments: Optional list of inline comment dicts (each with ``path``,
            ``position`` and ``body``). Left out of the request when empty.

    The outcome is printed; nothing is returned.
    """
    headers = {
        # The reviews endpoint answers with JSON, so ask for JSON rather than
        # the diff media type.
        "Accept": "application/vnd.github+json",
        "X-GitHub-Api-Version": "2022-11-28"
    }
    if GIT_API_KEY:
        headers["Authorization"] = f"token {GIT_API_KEY}"

    payload = {
        "body": f"{review}",
        "event": f"{decision}",
    }
    if comments:
        payload["comments"] = comments

    url = f"https://api.github.com/repos/{repo}/pulls/{pr_number}/reviews"
    # The payload has to be sent with the request; without it the API has no
    # review body or event to act on.
    response = requests.post(url, headers=headers, json=payload)
    if response.status_code == 200:
        review_data = response.json()
        print("Review submitted successfully.")
        print(f"Review ID: {review_data['id']}")
        print(f"State: {review_data['state']}")
        print(f"Submitted by: {review_data['user']['login']}")
        print(f"HTML URL: {review_data['html_url']}")
    else:
        print(f"Failed to submit review: {response.status_code} - {response.text}")


def request_review(repo, pr_number, reviewer):
    """Ask ``reviewer`` to review a pull request and print the outcome.

    Sends a single-element ``reviewers`` list to the PR's
    ``requested_reviewers`` endpoint; HTTP 201 is treated as success.
    """
    endpoint = f"https://api.github.com/repos/{repo}/pulls/{pr_number}/requested_reviewers"
    reply = requests.post(
        endpoint,
        headers={
            "Authorization": f"token {GIT_API_KEY}",
            "Accept": "application/vnd.github+json"
        },
        # Optional: a "team_reviewers": ["team-slug"] entry can be added here
        json={"reviewers": [reviewer]},
    )

    if reply.status_code != 201:
        print(f"Failed to request reviewers: {reply.status_code} - {reply.text}")
        return

    print("Reviewers requested successfully.")


if __name__ == "__main__":
    # Manual smoke run against a test repository: fetch the supported diffs
    # of one pull request. The other helpers in this module
    # (get_pr_details, get_pr_diff, get_pull_requests, print_pull_requests,
    # request_review, post_review) can be called here the same way when
    # experimenting.
    repo = "ississippi/pull-request-automated-review"
    pr_number = 13
    get_supported_diffs(repo, pr_number)
Binary file added pr_chat/static/favicon.ico
Binary file not shown.
Loading