44
55
66import numpy as np
7- import matplotlib .pylab as pl
7+ import matplotlib .pylab as plt
88from matplotlib import gridspec
99
1010
11- def plot1D_mat (a ,b , M , title = '' ):
12- """ Plot matrix M with the source and target 1D distribution
13-
14- Creates a subplot with the source distribution a on the left and
11+ def plot1D_mat (a , b , M , title = '' ):
12+ """ Plot matrix M with the source and target 1D distribution
13+
14+ Creates a subplot with the source distribution a on the left and
1515 target distribution b on the tot. The matrix M is shown in between.
16-
17-
16+
17+
1818 Parameters
1919 ----------
20-
21- a : np.array (na,)
20+ a : np.array, shape (na,)
2221 Source distribution
23- b : np.array (nb,)
24- Target distribution
25- M : np.array (na,nb)
22+ b : np.array, shape (nb,)
23+ Target distribution
24+ M : np.array, shape (na,nb)
2625 Matrix to plot
27-
28-
29-
3026 """
31-
32- na = M .shape [0 ]
33- nb = M .shape [1 ]
34-
27+ na , nb = M .shape
28+
3529 gs = gridspec .GridSpec (3 , 3 )
36-
37-
38- xa = np .arange (na )
39- xb = np .arange (nb )
40-
41-
42- ax1 = pl .subplot (gs [0 ,1 :])
43- pl .plot (xb ,b ,'r' ,label = 'Target distribution' )
44- pl .yticks (())
45- pl .title (title )
46-
47- #pl.axis('off')
48-
49- ax2 = pl .subplot (gs [1 :,0 ])
50- pl .plot (a ,xa ,'b' ,label = 'Source distribution' )
51- pl .gca ().invert_xaxis ()
52- pl .gca ().invert_yaxis ()
53- pl .xticks (())
54- #pl.ylim((0,n))
55- #pl.axis('off')
56-
57- pl .subplot (gs [1 :,1 :],sharex = ax1 ,sharey = ax2 )
58- pl .imshow (M ,interpolation = 'nearest' )
59-
60- pl .xlim ((0 ,nb ))
61-
62-
63- def plot2D_samples_mat (xs ,xt ,G ,thr = 1e-8 ,** kwargs ):
30+
31+ xa = np .arange (na )
32+ xb = np .arange (nb )
33+
34+ ax1 = plt .subplot (gs [0 , 1 :])
35+ plt .plot (xb , b , 'r' , label = 'Target distribution' )
36+ plt .yticks (())
37+ plt .title (title )
38+
39+ ax2 = plt .subplot (gs [1 :, 0 ])
40+ plt .plot (a , xa , 'b' , label = 'Source distribution' )
41+ plt .gca ().invert_xaxis ()
42+ plt .gca ().invert_yaxis ()
43+ plt .xticks (())
44+
45+ plt .subplot (gs [1 :, 1 :], sharex = ax1 , sharey = ax2 )
46+ plt .imshow (M , interpolation = 'nearest' )
47+ plt .axis ('off' )
48+
49+ plt .xlim ((0 , nb ))
50+ plt .tight_layout ()
51+ plt .subplots_adjust (wspace = 0. , hspace = 0.2 )
52+
53+
54+ def plot2D_samples_mat (xs , xt , G , thr = 1e-8 , ** kwargs ):
6455 """ Plot matrix M in 2D with lines using alpha values
65-
66- Plot lines between source and target 2D samples with a color
56+
57+ Plot lines between source and target 2D samples with a color
6758 proportional to the value of the matrix G between samples.
68-
69-
59+
60+
7061 Parameters
7162 ----------
72-
73- xs : np.array (ns,2)
63+ xs : ndarray, shape (ns,2)
7464 Source samples positions
75- b : np.array (nt,2)
65+ b : ndarray, shape (nt,2)
7666 Target samples positions
77- G : np.array (na,nb)
67+ G : ndarray, shape (na,nb)
7868 OT matrix
7969 thr : float, optional
8070 threshold above which the line is drawn
8171 **kwargs : dict
82- paameters given to the plot functions (default color is black if nothing given)
83-
72+ paameters given to the plot functions (default color is black if
73+ nothing given)
8474 """
85- if ('color' not in kwargs ) and ('c' not in kwargs ):
86- kwargs ['color' ]= 'k'
87- mx = G .max ()
75+ if ('color' not in kwargs ) and ('c' not in kwargs ):
76+ kwargs ['color' ] = 'k'
77+ mx = G .max ()
8878 for i in range (xs .shape [0 ]):
8979 for j in range (xt .shape [0 ]):
90- if G [i ,j ] / mx > thr :
91- pl .plot ([xs [i ,0 ],xt [j ,0 ]],[xs [i ,1 ],xt [j ,1 ]],alpha = G [ i , j ] / mx , ** kwargs )
92-
80+ if G [i , j ] / mx > thr :
81+ plt .plot ([xs [i , 0 ], xt [j , 0 ]], [xs [i , 1 ], xt [j , 1 ]],
82+ alpha = G [ i , j ] / mx , ** kwargs )
0 commit comments