Skip to content

Commit 2933531

Browse files
committed
add notebook
1 parent c48829d commit 2933531

File tree

3 files changed

+2025
-731
lines changed

3 files changed

+2025
-731
lines changed

examples/Demo_2D_OT_DomainAdaptation.ipynb

Lines changed: 217 additions & 0 deletions
Large diffs are not rendered by default.

examples/demo_OTDA_classes.py

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,13 @@
1111

1212
#%% parameters
1313

14-
n=150 # nb bins
14+
n=150 # nb samples in source and target datasets
1515

1616
xs,ys=ot.datasets.get_data_classif('3gauss',n)
1717
xt,yt=ot.datasets.get_data_classif('3gauss2',n)
1818

19-
a,b = ot.unif(n),ot.unif(n)
20-
# loss matrix
21-
M=ot.dist(xs,xt)
22-
#M/=M.max()
19+
20+
2321

2422
#%% plot samples
2523

@@ -38,17 +36,13 @@
3836

3937
#%% OT estimation
4038

41-
# EMD
42-
43-
44-
da_emd=ot.da.OTDA()
45-
da_emd.fit(xs,xt)
46-
47-
# interpolate samples
48-
xst0=da_emd.interp()
39+
# LP problem
40+
da_emd=ot.da.OTDA() # init class
41+
da_emd.fit(xs,xt) # fit distributions
42+
xst0=da_emd.interp() # interpolation of source samples
4943

5044

51-
# sinkhorn
45+
# sinkhorn regularization
5246
lambd=1e-1
5347
da_entrop=ot.da.OTDA_sinkhorn()
5448
da_entrop.fit(xs,xt,reg=lambd)

0 commit comments

Comments
 (0)