Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 114 additions & 29 deletions backend/onyx/auth/email_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import base64
import smtplib
from datetime import datetime
from email.mime.image import MIMEImage
Expand All @@ -6,8 +7,21 @@
from email.utils import formatdate
from email.utils import make_msgid

import sendgrid # type: ignore
from sendgrid.helpers.mail import Attachment # type: ignore
from sendgrid.helpers.mail import Content
from sendgrid.helpers.mail import ContentId
from sendgrid.helpers.mail import Disposition
from sendgrid.helpers.mail import Email
from sendgrid.helpers.mail import FileContent
from sendgrid.helpers.mail import FileName
from sendgrid.helpers.mail import FileType
from sendgrid.helpers.mail import Mail
from sendgrid.helpers.mail import To

from onyx.configs.app_configs import EMAIL_CONFIGURED
from onyx.configs.app_configs import EMAIL_FROM
from onyx.configs.app_configs import SENDGRID_API_KEY
from onyx.configs.app_configs import SMTP_PASS
from onyx.configs.app_configs import SMTP_PORT
from onyx.configs.app_configs import SMTP_SERVER
Expand All @@ -18,11 +32,12 @@
from onyx.configs.constants import ONYX_SLACK_URL
from onyx.db.models import User
from onyx.server.runtime.onyx_runtime import OnyxRuntime
from onyx.utils.file import FileWithMimeType
from onyx.utils.logger import setup_logger
from onyx.utils.url import add_url_params
from onyx.utils.variable_functionality import fetch_versioned_implementation
from shared_configs.configs import MULTI_TENANT

logger = setup_logger()

HTML_EMAIL_TEMPLATE = """\
<!DOCTYPE html>
Expand Down Expand Up @@ -176,6 +191,70 @@ def send_email(
if not EMAIL_CONFIGURED:
raise ValueError("Email is not configured.")

if SENDGRID_API_KEY:
send_email_with_sendgrid(
user_email, subject, html_body, text_body, mail_from, inline_png
)
return

send_email_with_smtplib(
user_email, subject, html_body, text_body, mail_from, inline_png
)


def send_email_with_sendgrid(
user_email: str,
subject: str,
html_body: str,
text_body: str,
mail_from: str = EMAIL_FROM,
inline_png: tuple[str, bytes] | None = None,
) -> None:
from_email = Email(mail_from) if mail_from else Email("noreply@onyx.app")
to_email = To(user_email)

mail = Mail(
from_email=from_email,
to_emails=to_email,
subject=subject,
plain_text_content=Content("text/plain", text_body),
)

# Add HTML content
mail.add_content(Content("text/html", html_body))

if inline_png:
image_name, image_data = inline_png

# Create attachment
encoded_image = base64.b64encode(image_data).decode()
attachment = Attachment()
attachment.file_content = FileContent(encoded_image)
attachment.file_name = FileName(image_name)
attachment.file_type = FileType("image/png")
attachment.disposition = Disposition("inline")
attachment.content_id = ContentId(image_name)

mail.add_attachment(attachment)

# Get a JSON-ready representation of the Mail object
mail_json = mail.get()

sg = sendgrid.SendGridAPIClient(api_key=SENDGRID_API_KEY)
response = sg.client.mail.send.post(request_body=mail_json) # can raise
if response.status_code != 202:
logger.warning(f"Unexpected status code {response.status_code}")


def send_email_with_smtplib(
user_email: str,
subject: str,
html_body: str,
text_body: str,
mail_from: str = EMAIL_FROM,
inline_png: tuple[str, bytes] | None = None,
) -> None:

# Create a multipart/alternative message - this indicates these are alternative versions of the same content
msg = MIMEMultipart("alternative")
msg["Subject"] = subject
Expand Down Expand Up @@ -210,13 +289,10 @@ def send_email(
html_part = MIMEText(html_body, "html")
msg.attach(html_part)

try:
with smtplib.SMTP(SMTP_SERVER, SMTP_PORT) as s:
s.starttls()
s.login(SMTP_USER, SMTP_PASS)
s.send_message(msg)
except Exception as e:
raise e
with smtplib.SMTP(SMTP_SERVER, SMTP_PORT) as s:
s.starttls()
s.login(SMTP_USER, SMTP_PASS)
s.send_message(msg)


def send_subscription_cancellation_email(user_email: str) -> None:
Expand Down Expand Up @@ -264,27 +340,13 @@ def send_subscription_cancellation_email(user_email: str) -> None:
)


def send_user_email_invite(
user_email: str, current_user: User, auth_type: AuthType
) -> None:
onyx_file: FileWithMimeType | None = None

try:
load_runtime_settings_fn = fetch_versioned_implementation(
"onyx.server.enterprise_settings.store", "load_runtime_settings"
)
settings = load_runtime_settings_fn()
application_name = settings.application_name
except ModuleNotFoundError:
application_name = ONYX_DEFAULT_APPLICATION_NAME

onyx_file = OnyxRuntime.get_emailable_logo()

subject = f"Invitation to Join {application_name} Organization"
def build_user_email_invite(
from_email: str, to_email: str, application_name: str, auth_type: AuthType
) -> tuple[str, str]:
heading = "You've Been Invited!"

# the exact action taken by the user, and thus the message, depends on the auth type
message = f"<p>You have been invited by {current_user.email} to join an organization on {application_name}.</p>"
message = f"<p>You have been invited by {from_email} to join an organization on {application_name}.</p>"
if auth_type == AuthType.CLOUD:
message += (
"<p>To join the organization, please click the button below to set a password "
Expand All @@ -309,7 +371,7 @@ def send_user_email_invite(
raise ValueError(f"Invalid auth type: {auth_type}")

cta_text = "Join Organization"
cta_link = f"{WEB_DOMAIN}/auth/signup?email={user_email}"
cta_link = f"{WEB_DOMAIN}/auth/signup?email={to_email}"

html_content = build_html_email(
application_name,
Expand All @@ -322,13 +384,36 @@ def send_user_email_invite(
# text content is the fallback for clients that don't support HTML
# not as critical, so not having special cases for each auth type
text_content = (
f"You have been invited by {current_user.email} to join an organization on {application_name}.\n"
f"You have been invited by {from_email} to join an organization on {application_name}.\n"
"To join the organization, please visit the following link:\n"
f"{WEB_DOMAIN}/auth/signup?email={user_email}\n"
f"{WEB_DOMAIN}/auth/signup?email={to_email}\n"
)
if auth_type == AuthType.CLOUD:
text_content += "You'll be asked to set a password or login with Google to complete your registration."

return text_content, html_content


def send_user_email_invite(
user_email: str, current_user: User, auth_type: AuthType
) -> None:
try:
load_runtime_settings_fn = fetch_versioned_implementation(
"onyx.server.enterprise_settings.store", "load_runtime_settings"
)
settings = load_runtime_settings_fn()
application_name = settings.application_name
except ModuleNotFoundError:
application_name = ONYX_DEFAULT_APPLICATION_NAME

onyx_file = OnyxRuntime.get_emailable_logo()

subject = f"Invitation to Join {application_name} Organization"

text_content, html_content = build_user_email_invite(
current_user.email, user_email, application_name, auth_type
)

send_email(
user_email,
subject,
Expand Down
16 changes: 8 additions & 8 deletions backend/onyx/chat/process_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,13 +211,13 @@ def _handle_search_tool_response_summary(
user_files: list[UserFile] | None = None,
loaded_user_files: list[InMemoryChatFile] | None = None,
) -> tuple[QADocsResponse, list[DbSearchDoc], list[int] | None]:
response_sumary = cast(SearchResponseSummary, packet.response)
response_summary = cast(SearchResponseSummary, packet.response)

is_extended = isinstance(packet, ExtendedToolResponse)
dropped_inds = None

if not selected_search_docs:
top_docs = chunks_or_sections_to_search_docs(response_sumary.top_sections)
top_docs = chunks_or_sections_to_search_docs(response_summary.top_sections)

deduped_docs = top_docs
if (
Expand Down Expand Up @@ -264,13 +264,13 @@ def _handle_search_tool_response_summary(
level, question_num = packet.level, packet.level_question_num
return (
QADocsResponse(
rephrased_query=response_sumary.rephrased_query,
rephrased_query=response_summary.rephrased_query,
top_documents=response_docs,
predicted_flow=response_sumary.predicted_flow,
predicted_search=response_sumary.predicted_search,
applied_source_filters=response_sumary.final_filters.source_type,
applied_time_cutoff=response_sumary.final_filters.time_cutoff,
recency_bias_multiplier=response_sumary.recency_bias_multiplier,
predicted_flow=response_summary.predicted_flow,
predicted_search=response_summary.predicted_search,
applied_source_filters=response_summary.final_filters.source_type,
applied_time_cutoff=response_summary.final_filters.time_cutoff,
recency_bias_multiplier=response_summary.recency_bias_multiplier,
level=level,
level_question_num=question_num,
),
Expand Down
4 changes: 3 additions & 1 deletion backend/onyx/configs/app_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,11 @@
SMTP_PORT = int(os.environ.get("SMTP_PORT") or "587")
SMTP_USER = os.environ.get("SMTP_USER", "your-email@gmail.com")
SMTP_PASS = os.environ.get("SMTP_PASS", "your-gmail-password")
EMAIL_CONFIGURED = all([SMTP_SERVER, SMTP_USER, SMTP_PASS])
EMAIL_FROM = os.environ.get("EMAIL_FROM") or SMTP_USER

SENDGRID_API_KEY = os.environ.get("SENDGRID_API_KEY") or ""
EMAIL_CONFIGURED = all([SMTP_SERVER, SMTP_USER, SMTP_PASS]) or SENDGRID_API_KEY

# If set, Onyx will listen to the `expires_at` returned by the identity
# provider (e.g. Okta, Google, etc.) and force the user to re-authenticate
# after this time has elapsed. Disabled since by default many auth providers
Expand Down
1 change: 1 addition & 0 deletions backend/requirements/default.txt
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,4 @@ sentry-sdk==2.14.0
prometheus_client==0.21.0
fastapi-limiter==0.1.6
prometheus_fastapi_instrumentator==7.1.0
sendgrid==6.11.0
36 changes: 36 additions & 0 deletions backend/tests/unit/onyx/auth/test_email.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import pytest

from onyx.auth.email_utils import build_user_email_invite
from onyx.auth.email_utils import send_email
from onyx.configs.constants import AuthType
from onyx.configs.constants import ONYX_DEFAULT_APPLICATION_NAME
from onyx.db.engine import SqlEngine
from onyx.server.runtime.onyx_runtime import OnyxRuntime


@pytest.mark.skip(
reason="This sends real emails, so only run when you really want to test this!"
)
def test_send_user_email_invite() -> None:
SqlEngine.init_engine(pool_size=20, max_overflow=5)

application_name = ONYX_DEFAULT_APPLICATION_NAME

onyx_file = OnyxRuntime.get_emailable_logo()

subject = f"Invitation to Join {application_name} Organization"

FROM_EMAIL = "noreply@onyx.app"
TO_EMAIL = "support@onyx.app"
text_content, html_content = build_user_email_invite(
FROM_EMAIL, TO_EMAIL, ONYX_DEFAULT_APPLICATION_NAME, AuthType.CLOUD
)

send_email(
TO_EMAIL,
subject,
html_content,
text_content,
mail_from=FROM_EMAIL,
inline_png=("logo.png", onyx_file.data),
)