1
1
from __future__ import annotations
2
2
from typing import TYPE_CHECKING
3
+ from PySide6 .QtCore import Qt
3
4
from PySide6 .QtWidgets import (
4
5
QDialog ,
5
6
QVBoxLayout ,
13
14
)
14
15
15
16
if TYPE_CHECKING :
16
- from .object_explorer .asset_base import AssetBase
17
+ from .object_explorer import AssetItem , AttackerItem
18
+ from maltoolbox .language import LanguageGraph , LanguageGraphAsset
19
+ from maltoolbox .model import Model , ModelAsset
17
20
18
21
class ConnectionDialog (QDialog ):
19
22
def filter_items (self , text ):
@@ -26,16 +29,15 @@ def ok_button_clicked(self):
26
29
class AssociationConnectionDialog (ConnectionDialog ):
27
30
def __init__ (
28
31
self ,
29
- start_item : AssetBase ,
30
- end_item : AssetBase ,
31
- lang_graph ,
32
- lcs , model ,
32
+ start_item : AssetItem ,
33
+ end_item : AssetItem ,
34
+ lang_graph : LanguageGraph ,
35
+ model : Model ,
33
36
parent = None
34
37
):
35
38
super ().__init__ (parent )
36
39
37
- self .lang_graph = lang_graph
38
- self .lcs = lcs
40
+ self .lang_graph : LanguageGraph = lang_graph
39
41
self .model = model
40
42
41
43
self .setWindowTitle ("Select Association Type" )
@@ -46,75 +48,38 @@ def __init__(
46
48
47
49
self .association_list_widget = QListWidget ()
48
50
49
- start_asset = start_item .asset
50
- end_asset = end_item .asset
51
- self .start_asset_type = start_asset .type
52
- self .end_asset_type = end_asset .type
53
- self .start_asset_name = start_asset .name
54
- self .end_asset_name = end_asset .name
51
+ self .start_asset : ModelAsset = start_item .asset
52
+ self .end_asset : ModelAsset = end_item .asset
53
+ self .field_name = None
54
+
55
55
self .layout = QVBoxLayout ()
56
- self .label = \
57
- QLabel (f"{ self .start_asset_name } : { self .end_asset_name } " )
56
+ self .label = (
57
+ QLabel (f"{ self .start_asset .name } -> { self .end_asset .name } " )
58
+ )
58
59
self .layout .addWidget (self .label )
59
60
self .filter_edit = QLineEdit ()
60
61
self .filter_edit .setPlaceholderText ("Type to filter..." )
61
62
self .filter_edit .textChanged .connect (self .filter_items )
62
63
self .layout .addWidget (self .filter_edit )
63
- lang_graph_start_asset = next (
64
- (asset for asset in self .lang_graph .assets
65
- if asset .name == start_asset .type ), None
66
- )
67
- if lang_graph_start_asset is None :
68
- raise LookupError (f'Failed to find asset "{ start_asset .type } " '
69
- 'in language graph.' )
70
- lang_graph_end_asset = next (
71
- (asset for asset in self .lang_graph .assets
72
- if asset .name == end_asset .type ), None
73
- )
74
- if lang_graph_end_asset is None :
75
- raise LookupError (f'Failed to find asset "{ end_asset .type } " '
76
- 'in language graph.' )
77
-
78
- self ._str_to_assoc = {}
79
- for assoc in lang_graph_start_asset .associations :
80
- asset_pairs = []
81
- opposite_asset = assoc .get_opposite_asset (lang_graph_start_asset )
82
- # Check if the other side of the association matches the other end
83
- # and if the exact association does not already exist in the
84
- # model.
85
- if lang_graph_end_asset .is_subasset_of (opposite_asset ):
86
- print ("IDENTIFIED MATCH ++++++++++++" )
87
- if lang_graph_start_asset .is_subasset_of (assoc .left_field .asset ):
88
- asset_pairs .append ((start_asset , end_asset ))
89
- else :
90
- asset_pairs .append ((end_asset , start_asset ))
91
- if lang_graph_start_asset .is_subasset_of (opposite_asset ):
92
- # The association could be applied either way, add the
93
- # reverse association as well.
94
- other_asset = assoc .get_opposite_asset (opposite_asset )
95
- # Check if the other side of the association matches the other end
96
- # and if the exact association does not already exist in the
97
- # model.
98
- if lang_graph_end_asset .is_subasset_of (other_asset ):
99
- print ("REVERSE ASSOC ++++++++++++" )
100
- # We need to create the reverse association as well
101
- asset_pairs .append ((end_asset , start_asset ))
102
- for (left_asset , right_asset ) in asset_pairs :
103
- if not self .model .association_exists_between_assets (
104
- assoc .name ,
105
- left_asset ,
106
- right_asset ):
107
- formatted_assoc_str = left_asset .name + "." + \
108
- assoc .left_field .fieldname + "-->" + \
109
- assoc .name + "-->" + \
110
- right_asset .name + "." + \
111
- assoc .right_field .fieldname
112
- self ._str_to_assoc [formatted_assoc_str ] = (
113
- assoc ,
114
- left_asset ,
115
- right_asset
116
- )
117
- self .association_list_widget .addItem (QListWidgetItem (formatted_assoc_str ))
64
+
65
+ possible_assocs = self .start_asset .lg_asset .associations
66
+ for fieldname , association in possible_assocs .items ():
67
+ field = association .get_field (fieldname )
68
+
69
+ # If assoc ends with end_assets type, give that assoc
70
+ # as option in list widget
71
+ if field .asset == self .end_asset .lg_asset :
72
+ assoc_list_item = QListWidgetItem (self .start_asset .name + "." + fieldname + " = " + self .end_asset .name )
73
+ assoc_list_item .setData (
74
+ Qt .UserRole ,
75
+ {
76
+ 'from' : self .start_asset ,
77
+ 'to' : self .end_asset ,
78
+ 'fieldname' : fieldname
79
+ }
80
+ )
81
+ self .association_list_widget .addItem (assoc_list_item )
82
+
118
83
self .layout .addWidget (self .association_list_widget )
119
84
120
85
button_layout = QHBoxLayout ()
@@ -142,65 +107,53 @@ def filter_items(self, text):
142
107
143
108
def ok_button_clicked (self ):
144
109
selected_item = self .association_list_widget .currentItem ()
110
+
145
111
if selected_item :
146
- selected_association_text = selected_item .text ()
147
- # QMessageBox.information(self, "Selected Item", f"You selected: {selected_association_text}")
148
-
149
- (assoc , left_asset , right_asset ) = \
150
- self ._str_to_assoc [selected_association_text ]
151
- # TODO: Create association based on its full name instead in order
152
- # to avoid conflicts when multiple associations with the same name
153
- # exist.
154
- association = getattr (self .lcs .ns , assoc .name )()
155
- print (
156
- f'N:{ assoc .name } LF:{ assoc .left_field .fieldname } '
157
- f'LA:{ left_asset .name } RF:{ assoc .right_field .fieldname } '
158
- f'RA:{ right_asset .name } '
159
- )
160
- setattr (association , assoc .left_field .fieldname , [left_asset ])
161
- setattr (association , assoc .right_field .fieldname , [right_asset ])
162
- selected_item .association = association
163
- # self.model.add_association(association)
112
+ data = selected_item .data (Qt .UserRole )
113
+
114
+ from_asset : ModelAsset = data .get ('from' )
115
+ to_asset : ModelAsset = data .get ('to' )
116
+ self .field_name : str = data .get ('fieldname' )
117
+
118
+ print (f'{ from_asset } .{ self .field_name } = { to_asset } chosen' )
119
+
120
+
164
121
self .accept ()
165
122
166
123
class EntrypointConnectionDialog (ConnectionDialog ):
167
124
def __init__ (
168
125
self ,
169
- attacker_item ,
170
- asset_item ,
171
- lang_graph ,
172
- lcs ,
126
+ attacker_item : AttackerItem ,
127
+ asset_item : AssetItem ,
128
+ lang_graph : LanguageGraph ,
173
129
model ,
174
130
parent = None
175
131
):
176
132
super ().__init__ (parent )
177
133
178
134
self .lang_graph = lang_graph
179
- self .lcs = lcs
180
135
self .model = model
181
136
182
137
self .setWindowTitle ("Select Entry Point" )
183
138
self .setMinimumWidth (300 )
184
139
185
- print (f'Attacker ITEM TYPE { attacker_item .asset_type } ' )
186
- print (f'Asset ITEM TYPE { asset_item .asset_type } ' )
187
-
188
140
self .attack_step_list_widget = QListWidget ()
189
- attacker = attacker_item .attackerAttachment
141
+ attacker = attacker_item .attacker
190
142
191
143
if asset_item .asset is not None :
192
- asset_type = \
193
- self .lang_graph .get_asset_by_name (asset_item .asset .type )
144
+ asset_type = self .lang_graph .assets [asset_item .asset .type ]
194
145
195
146
# Find asset attack steps already part of attacker entry points
196
147
entry_point_tuple = attacker .get_entry_point_tuple (
197
- asset_item .asset )
148
+ asset_item .asset
149
+ )
150
+
198
151
if entry_point_tuple is not None :
199
152
entry_point_attack_steps = entry_point_tuple [1 ]
200
153
else :
201
154
entry_point_attack_steps = []
202
155
203
- for attack_step in asset_type .attack_steps :
156
+ for attack_step in asset_type .attack_steps . values () :
204
157
if attack_step .type not in ['or' , 'and' ]:
205
158
continue
206
159
0 commit comments