import os
from tqdm import tqdm
import logging
+ from configs.database import get_key_database
+
+ keys_db = get_key_database()
+ keys_collection = keys_db["keys"]
+
+ # Neo4j Connection Details
+ NEO4J_URI = keys_collection.find_one({"_id": "NEO4J_URI"})["api_key"]  # Replace with your Neo4j URI
+ NEO4J_USERNAME = "neo4j"  # Replace with your Neo4j username
+ NEO4J_PASSWORD = keys_collection.find_one({"_id": "NEO4J_PASSWORD"})["api_key"]  # Replace with your Neo4j password

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
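The helper get_key_database lives in configs/database.py, which is not part of this diff. Judging from the find_one({"_id": ...})["api_key"] lookups above, the keys appear to sit in a MongoDB collection; a minimal sketch of such a helper, assuming pymongo and treating the connection string and database name as placeholders:

import os
from pymongo import MongoClient

def get_key_database():
    # Hypothetical stand-in for configs/database.py, not taken from this repository.
    # Assumes key documents are shaped like {"_id": "NEO4J_URI", "api_key": "<value>"}.
    client = MongoClient(os.environ.get("MONGO_URI", "mongodb://localhost:27017"))
    return client["keys_db"]  # database name is a placeholder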
}

# Neo4j connection details from environment variables
- uri = "neo4j+ssc://7bf5a48e.databases.neo4j.io"
- AUTH = ("neo4j", "oxsK7V5_86emZlYQlvCfQHfVWS95wXz29OhtU8GAdFc")
+ uri = NEO4J_URI
+ AUTH = (NEO4J_USERNAME, NEO4J_PASSWORD)

# Initialize Neo4j driver
driver = GraphDatabase.driver(uri, auth=AUTH)
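The lines between the driver initialization above and the driver.close() / exit(1) just below fall outside this diff's visible context; they presumably verify the connection. Purely as an illustration, not the file's actual code, such a check with the official neo4j Python driver usually looks roughly like this:

try:
    # verify_connectivity() is part of the official neo4j Python driver.
    driver.verify_connectivity()
except Exception as e:
    # Placeholder handling that mirrors the close()/exit(1) visible just below.
    logging.error(f"Could not connect to Neo4j: {e}")
    driver.close()
    exit(1)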
    driver.close()
    exit(1)

+
# Function to load node CSV files into DataFrames
def load_node_dataframes(csv_dir, node_types):
    node_dfs = {}
@@ -57,6 +67,7 @@ def load_node_dataframes(csv_dir, node_types):
        logger.warning(f"CSV file for node type '{node_type}' not found in '{csv_dir}'.")
    return node_dfs

+
# Function to load relationships CSV file into a DataFrame
def load_relationships_data(csv_dir):
    relationships_file = os.path.join(csv_dir, 'relationships.csv')
@@ -68,6 +79,7 @@ def load_relationships_data(csv_dir):
        logger.warning(f"Relationships CSV file not found in '{csv_dir}'.")
    return None

+
# Function to create constraints
def create_constraints(driver):
    constraints = [
@@ -91,12 +103,15 @@ def create_constraints(driver):
            logger.error(f"Failed to execute constraint '{constraint}': {e}")
    logger.info("Constraints created or already exist.")

+
def standardize_relationship_types(df):
    if 'relationship_type' in df.columns:
        original_types = df['relationship_type'].unique()
-         df['relationship_type'] = df['relationship_type'].str.upper().str.replace(' ', '_').str.replace('[^A-Z0-9_]', '', regex=True)
+         df['relationship_type'] = df['relationship_type'].str.upper().str.replace(' ', '_').str.replace('[^A-Z0-9_]',
+                                                                                                          '', regex=True)
        standardized_types = df['relationship_type'].unique()
-         logger.info(f"Standardized relationship types from {len(original_types)} to {len(standardized_types)} unique types.")
+         logger.info(
+             f"Standardized relationship types from {len(original_types)} to {len(standardized_types)} unique types.")
    return df

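As a quick illustration of what this standardization does (the input labels below are hypothetical, not from the actual dataset):

import pandas as pd

df = pd.DataFrame({'relationship_type': ['works with', 'reports-to', 'Located In!']})
df['relationship_type'] = df['relationship_type'].str.upper().str.replace(' ', '_').str.replace('[^A-Z0-9_]', '', regex=True)
print(df['relationship_type'].tolist())  # ['WORKS_WITH', 'REPORTSTO', 'LOCATED_IN']

Note that characters such as hyphens are dropped rather than replaced with underscores, which is worth keeping in mind if the source labels use them.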
@@ -116,12 +131,13 @@ def import_nodes_in_batches(tx, node_type, df, batch_size=1000):
        df['embedding'] = df['embedding'].apply(lambda x: json.loads(x) if pd.notnull(x) else [])
    data = df.to_dict('records')
    for i in tqdm(range(0, len(data), batch_size), desc=f"Importing {node_type} in batches"):
-         batch = data[i:i + batch_size]
+         batch = data[i:i + batch_size]
        try:
            tx.run(query, rows=batch)
-             logger.info(f"Imported batch {i // batch_size + 1} for node type '{node_type}'.")
+             logger.info(f"Imported batch {i // batch_size + 1} for node type '{node_type}'.")
        except Exception as e:
-             logger.error(f"Error importing batch {i // batch_size + 1} for node type '{node_type}': {e}")
+             logger.error(f"Error importing batch {i // batch_size + 1} for node type '{node_type}': {e}")
+

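The Cypher text in query is built outside this hunk. For orientation only, a batched node import driven by tx.run(query, rows=batch) typically follows the UNWIND pattern below; the Person label and the name/embedding properties are hypothetical, not taken from this file:

# Generic sketch of a batched UNWIND upsert; not the script's actual Cypher.
sketch_node_query = """
UNWIND $rows AS row
MERGE (n:Person {id: row.id})
SET n.name = row.name,
    n.embedding = row.embedding
"""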
# Function to create a mapping from ID to node type
def create_id_to_type_mapping(node_dfs):
@@ -135,6 +151,7 @@ def create_id_to_type_mapping(node_dfs):
    logger.info("Created ID to node type mapping.")
    return id_to_type

+
# Function to infer node types for relationships
def infer_node_types(rel_df, id_to_type):
    rel_df['start_node_type'] = rel_df['start_node_id'].apply(lambda x: id_to_type.get(int(x), 'Unknown'))
@@ -149,10 +166,11 @@ def infer_node_types(rel_df, id_to_type):
        logger.warning(unknown_end)
    return rel_df

+
def import_relationships_in_batches(tx, df, batch_size=1000):
    data = df.to_dict('records')
    for i in tqdm(range(0, len(data), batch_size), desc="Importing relationships in batches"):
-         batch = data[i:i + batch_size]
+         batch = data[i:i + batch_size]
        unwind_data = [
            {
                "start_id": int(rel['start_node_id']),
@@ -170,9 +188,9 @@ def import_relationships_in_batches(tx, df, batch_size=1000):
        """
        try:
            tx.run(query, rows=unwind_data)
-             logger.info(f"Imported batch {i // batch_size + 1} of relationships.")
+             logger.info(f"Imported batch {i // batch_size + 1} of relationships.")
        except Exception as e:
-             logger.error(f"Error importing batch {i // batch_size + 1} of relationships: {e}")
+             logger.error(f"Error importing batch {i // batch_size + 1} of relationships: {e}")

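As with the node import, the relationship query itself sits outside this hunk. One detail worth noting: Cypher does not allow a relationship type to be bound as a query parameter, so scripts like this usually either interpolate the standardized type into the query text or call apoc.merge.relationship. A generic sketch of the APOC variant, assuming the plugin is installed; only start_id is visible above, so the end_id and type field names are assumptions:

# Illustrative only; not the script's actual Cypher.
sketch_rel_query = """
UNWIND $rows AS row
MATCH (a {id: row.start_id})
MATCH (b {id: row.end_id})
CALL apoc.merge.relationship(a, row.type, {}, {}, b, {}) YIELD rel
RETURN count(rel)
"""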
# Main function to perform the import
@@ -198,25 +216,25 @@ def main():
    if relationship_df is not None:
        # Standardize relationship types
        relationship_df = standardize_relationship_types(relationship_df)
-
+
        # Infer node types if not present
        if 'start_node_type' not in relationship_df.columns or 'end_node_type' not in relationship_df.columns:
            logger.info("Inferring 'start_node_type' and 'end_node_type' based on node IDs...")
            relationship_df = infer_node_types(relationship_df, id_to_type)

        # Check for unknown node types
        unknown_rels = relationship_df[
-             (relationship_df['start_node_type'] == 'Unknown') |
+             (relationship_df['start_node_type'] == 'Unknown') |
            (relationship_df['end_node_type'] == 'Unknown')
-         ]
+             ]
        if not unknown_rels.empty:
            logger.error("Some relationships have unknown node types. Please verify your data.")
            logger.error(unknown_rels)
            # Skip unknown relationships
            relationship_df = relationship_df[
-                 (relationship_df['start_node_type'] != 'Unknown') &
+                 (relationship_df['start_node_type'] != 'Unknown') &
                (relationship_df['end_node_type'] != 'Unknown')
-             ]
+                 ]

        # Import relationships
        with driver.session() as session:
@@ -229,5 +247,6 @@ def main():
    driver.close()
    logger.info("Neo4j import completed.")

+
if __name__ == "__main__":
    main()