Merged
69 changes: 43 additions & 26 deletions backend/onyx/connectors/salesforce/connector.py
@@ -112,42 +112,20 @@ def reconstruct_object_types(directory: str) -> dict[str, list[str] | None]:
@staticmethod
def _download_object_csvs(
sf_db: OnyxSalesforceSQLite,
all_types_to_filter: dict[str, bool],
directory: str,
parent_object_list: list[str],
sf_client: Salesforce,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> None:
all_types: set[str] = set(parent_object_list)

logger.info(
f"Parent object types: num={len(parent_object_list)} list={parent_object_list}"
)

# This takes like 20 seconds
for parent_object_type in parent_object_list:
child_types = get_all_children_of_sf_type(sf_client, parent_object_type)
logger.debug(
f"Found {len(child_types)} child types for {parent_object_type}"
)

all_types.update(child_types)

# Always want to make sure user is grabbed for permissioning purposes
all_types.add("User")

logger.info(f"All object types: num={len(all_types)} list={all_types}")

# gc.collect()

# checkpoint - we've found all object types, now time to fetch the data
logger.info("Fetching CSVs for all object types")

# This takes like 30 minutes first time and <2 minutes for updates
object_type_to_csv_path = fetch_all_csvs_in_parallel(
sf_db=sf_db,
sf_client=sf_client,
object_types=all_types,
all_types_to_filter=all_types_to_filter,
start=start,
end=end,
target_dir=directory,
@@ -224,6 +202,30 @@ def _load_csvs_to_db(csv_directory: str, sf_db: OnyxSalesforceSQLite) -> set[str

return updated_ids

@staticmethod
def _get_all_types(parent_types: list[str], sf_client: Salesforce) -> set[str]:
all_types: set[str] = set(parent_types)

# Step 1 - get all object types
logger.info(f"Parent object types: num={len(parent_types)} list={parent_types}")

# This takes like 20 seconds
for parent_object_type in parent_types:
child_types = get_all_children_of_sf_type(sf_client, parent_object_type)
logger.debug(
f"Found {len(child_types)} child types for {parent_object_type}"
)

all_types.update(child_types)

# Always want to make sure user is grabbed for permissioning purposes
all_types.add("User")

logger.info(f"All object types: num={len(all_types)} list={all_types}")

# gc.collect()
return all_types

def _fetch_from_salesforce(
self,
temp_dir: str,
@@ -244,9 +246,24 @@ def _fetch_from_salesforce(
sf_db.apply_schema()
sf_db.log_stats()

# Step 1 - download
# Step 1.1 - add child object types + "User" type to the list of types
all_types = SalesforceConnector._get_all_types(
self.parent_object_list, self._sf_client
)

"""Only add time filter if there is at least one object of the type
in the database. We aren't worried about partially completed object update runs
because this occurs after we check for existing csvs which covers this case"""
all_types_to_filter: dict[str, bool] = {}
for sf_type in all_types:
if sf_db.has_at_least_one_object_of_type(sf_type):
Review comment (Contributor): slight preference for
    all_types_to_filter[sf_type] = sf_db.has_at_least_one_object_of_type(sf_type)
(a sketch of this simplification follows the if/else below)

all_types_to_filter[sf_type] = True
else:
all_types_to_filter[sf_type] = False
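# A minimal sketch of the reviewer's suggested simplification (hypothetical, not part of
# the merged diff): the if/else above collapses into one assignment per type, e.g.
#   all_types_to_filter = {
#       sf_type: sf_db.has_at_least_one_object_of_type(sf_type)
#       for sf_type in all_types
#   }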

# Step 1.2 - bulk download the CSV for each object type
SalesforceConnector._download_object_csvs(
sf_db, temp_dir, self.parent_object_list, self._sf_client, start, end
sf_db, all_types_to_filter, temp_dir, self._sf_client, start, end
)
gc.collect()

26 changes: 11 additions & 15 deletions backend/onyx/connectors/salesforce/salesforce_calls.py
@@ -186,7 +186,7 @@ def _bulk_retrieve_from_salesforce(
def fetch_all_csvs_in_parallel(
sf_db: OnyxSalesforceSQLite,
sf_client: Salesforce,
object_types: set[str],
all_types_to_filter: dict[str, bool],
start: SecondsSinceUnixEpoch | None,
end: SecondsSinceUnixEpoch | None,
target_dir: str,
@@ -219,20 +219,16 @@ def fetch_all_csvs_in_parallel(
)

time_filter_for_each_object_type = {}
# We do this outside of the thread pool executor because this requires
# a database connection and we don't want to block the thread pool
# executor from running
for sf_type in object_types:
"""Only add time filter if there is at least one object of the type
in the database. We aren't worried about partially completed object update runs
because this occurs after we check for existing csvs which covers this case"""
if sf_db.has_at_least_one_object_of_type(sf_type):
if sf_type in created_date_types:
time_filter_for_each_object_type[sf_type] = created_date_time_filter
else:
time_filter_for_each_object_type[sf_type] = last_modified_time_filter
else:

for sf_type, apply_filter in all_types_to_filter.items():
if not apply_filter:
time_filter_for_each_object_type[sf_type] = ""
continue

if sf_type in created_date_types:
time_filter_for_each_object_type[sf_type] = created_date_time_filter
else:
time_filter_for_each_object_type[sf_type] = last_modified_time_filter

# Run the bulk retrieve in parallel
with ThreadPoolExecutor() as executor:
@@ -243,6 +239,6 @@
time_filter=time_filter_for_each_object_type[object_type],
target_dir=target_dir,
),
object_types,
all_types_to_filter.keys(),
)
return dict(results)
25 changes: 0 additions & 25 deletions backend/onyx/connectors/salesforce/sqlite_functions.py
@@ -572,28 +572,3 @@ def _update_user_email_map(cursor: sqlite3.Cursor) -> None:
AND json_extract(data, '$.Email') IS NOT NULL
"""
)


# @contextmanager
# def get_db_connection(
# directory: str,
# isolation_level: str | None = None,
# ) -> Iterator[sqlite3.Connection]:
# """Get a database connection with proper isolation level and error handling.

# Args:
# isolation_level: SQLite isolation level. None = default "DEFERRED",
# can be "IMMEDIATE" or "EXCLUSIVE" for more strict isolation.
# """
# # 60 second timeout for locks
# conn = sqlite3.connect(get_sqlite_db_path(directory), timeout=60.0)

# if isolation_level is not None:
# conn.isolation_level = isolation_level
# try:
# yield conn
# except Exception:
# conn.rollback()
# raise
# finally:
# conn.close()