diff --git a/pgamit/__init__.py b/pgamit/__init__.py index 65571a83..08a3fc22 100644 --- a/pgamit/__init__.py +++ b/pgamit/__init__.py @@ -2,6 +2,7 @@ __all__ = [ 'cluster', 'network', + 'plots', 'pyRinexName', 'Utils', 'pyJobServer', diff --git a/pgamit/network.py b/pgamit/network.py index 47b901d1..c2222787 100644 --- a/pgamit/network.py +++ b/pgamit/network.py @@ -38,7 +38,7 @@ from pgamit.pyGamitSession import GamitSession from pgamit.pyStation import StationCollection from pgamit.cluster import over_cluster, select_central_point, BisectingQMeans -from pgamit.NetPlots import plot_global_network +from pgamit.plots import plot_global_network BACKBONE_NET = 45 NET_LIMIT = 40 @@ -214,19 +214,19 @@ def make_clusters(self, points, stations, net_limit=NET_LIMIT): # append to a regular list for integer indexing at line ~400 cluster_ties.append(my_cluster_ties) - # put everything in a dictionary - clusters = {'centroids': points[central_points], - 'labels': cluster_labels, - 'stations': station_labels} - # define output path for plot solution_base = self.GamitConfig.gamitopt['solutions_dir'].rstrip('/') end_path = '/%s/%s/%s' % (self.date.yyyy(), self.date.ddd(), self.name) path = solution_base + end_path + '_cluster.png' # generate plot of the network segmentation - plot_global_network(central_points, OC, qmean.labels_, points, - output_path=path, lat_lon=False) + central_points = plot_global_network(central_points, OC, qmean.labels_, + points, output_path=path) + + # put everything in a dictionary + clusters = {'centroids': points[central_points], + 'labels': cluster_labels, + 'stations': station_labels} return clusters, cluster_ties diff --git a/pgamit/NetPlots.py b/pgamit/plots.py similarity index 86% rename from pgamit/NetPlots.py rename to pgamit/plots.py index fc377d55..387f51b6 100644 --- a/pgamit/NetPlots.py +++ b/pgamit/plots.py @@ -8,6 +8,7 @@ import numpy as np import matplotlib.pyplot as plt from mpl_toolkits.basemap import Basemap +from sklearn.neighbors import NearestNeighbors from pgamit.Utils import ecef2lla @@ -86,10 +87,19 @@ def plot_global_network(central_points, OC, labels, points, # Flag centroid point remove = np.where(points == central_points[label])[0] points = points.tolist() - # remove centroid point so it's not repeated - points.pop(remove[0]) - # add central point to beginning so it's the central connection point - points.insert(0, central_points[label]) + try: + # remove centroid point so it's not repeated + points.pop(remove[0]) + # add same point to beginning of list + points.insert(0, central_points[label]) + except IndexError: + nbrs = NearestNeighbors(n_neighbors=1, algorithm='ball_tree', + metric='haversine').fit(LL[points]) + idx = nbrs.kneighbors(LL[central_points[label]].reshape(1, -1), + return_distance=False) + # add central point to beginning as the central connection point + points.insert(0, points.pop(idx.squeeze())) + central_points[label] = points[0] nx.add_star(nodes[label], points) for position, proj in zip(positions, projs): mxy = np.zeros_like(LL[points]) @@ -112,3 +122,5 @@ def plot_global_network(central_points, OC, labels, points, fig.supxlabel("Figure runtime: " + ("%.2fs" % (t1 - t0)).lstrip("0")) plt.savefig(output_path) plt.close() + + return central_points