@@ -118,16 +118,16 @@ def print_G(G, xs, ys, xt):
118118otda = ot .da .JCPOTTransport (reg_e = 1e-2 , max_iter = 1000 , metric = 'sqeuclidean' , tol = 1e-9 , verbose = True , log = True )
119119otda .fit (all_Xr , all_Yr , xt )
120120
121- ws1 = otda .proportions_ .dot (otda .log_ ['all_domains ' ][0 ][ 'D2' ])
122- ws2 = otda .proportions_ .dot (otda .log_ ['all_domains ' ][1 ][ 'D2' ])
121+ ws1 = otda .proportions_ .dot (otda .log_ ['D2 ' ][0 ])
122+ ws2 = otda .proportions_ .dot (otda .log_ ['D2 ' ][1 ])
123123
124124pl .figure (3 )
125125pl .clf ()
126126plot_ax (dec1 , 'Source 1' )
127127plot_ax (dec2 , 'Source 2' )
128128plot_ax (dect , 'Target' )
129- print_G (ot .bregman .sinkhorn (ws1 , [], otda .log_ ['all_domains ' ][0 ][ 'M' ], reg = 1e-2 ), xs1 , ys1 , xt )
130- print_G (ot .bregman .sinkhorn (ws2 , [], otda .log_ ['all_domains ' ][1 ][ 'M' ], reg = 1e-2 ), xs2 , ys2 , xt )
129+ print_G (ot .bregman .sinkhorn (ws1 , [], otda .log_ ['M ' ][0 ], reg = 1e-2 ), xs1 , ys1 , xt )
130+ print_G (ot .bregman .sinkhorn (ws2 , [], otda .log_ ['M ' ][1 ], reg = 1e-2 ), xs2 , ys2 , xt )
131131pl .scatter (xs1 [:, 0 ], xs1 [:, 1 ], c = ys1 , s = 35 , marker = 'x' , cmap = 'Set1' , vmax = 9 )
132132pl .scatter (xs2 [:, 0 ], xs2 [:, 1 ], c = ys2 , s = 35 , marker = '+' , cmap = 'Set1' , vmax = 9 )
133133pl .scatter (xt [:, 0 ], xt [:, 1 ], c = yt , s = 35 , marker = 'o' , cmap = 'Set1' , vmax = 9 )
@@ -146,16 +146,16 @@ def print_G(G, xs, ys, xt):
146146# ----------------------------------------------------------------------------
147147h_res = np .array ([1 - pt , pt ])
148148
149- ws1 = h_res .dot (otda .log_ ['all_domains ' ][0 ][ 'D2' ])
150- ws2 = h_res .dot (otda .log_ ['all_domains ' ][1 ][ 'D2' ])
149+ ws1 = h_res .dot (otda .log_ ['D2 ' ][0 ])
150+ ws2 = h_res .dot (otda .log_ ['D2 ' ][1 ])
151151
152152pl .figure (4 )
153153pl .clf ()
154154plot_ax (dec1 , 'Source 1' )
155155plot_ax (dec2 , 'Source 2' )
156156plot_ax (dect , 'Target' )
157- print_G (ot .bregman .sinkhorn (ws1 , [], otda .log_ ['all_domains ' ][0 ][ 'M' ], reg = 1e-2 ), xs1 , ys1 , xt )
158- print_G (ot .bregman .sinkhorn (ws2 , [], otda .log_ ['all_domains ' ][1 ][ 'M' ], reg = 1e-2 ), xs2 , ys2 , xt )
157+ print_G (ot .bregman .sinkhorn (ws1 , [], otda .log_ ['M ' ][0 ], reg = 1e-2 ), xs1 , ys1 , xt )
158+ print_G (ot .bregman .sinkhorn (ws2 , [], otda .log_ ['M ' ][1 ], reg = 1e-2 ), xs2 , ys2 , xt )
159159pl .scatter (xs1 [:, 0 ], xs1 [:, 1 ], c = ys1 , s = 35 , marker = 'x' , cmap = 'Set1' , vmax = 9 )
160160pl .scatter (xs2 [:, 0 ], xs2 [:, 1 ], c = ys2 , s = 35 , marker = '+' , cmap = 'Set1' , vmax = 9 )
161161pl .scatter (xt [:, 0 ], xt [:, 1 ], c = yt , s = 35 , marker = 'o' , cmap = 'Set1' , vmax = 9 )
0 commit comments