|
15 | 15 | from .bregman import sinkhorn |
16 | 16 | from .lp import emd |
17 | 17 | from .utils import unif, dist, kernel, cost_normalization |
18 | | -from .utils import check_params, deprecated, BaseEstimator |
| 18 | +from .utils import check_params, BaseEstimator |
19 | 19 | from .optim import cg |
20 | 20 | from .optim import gcg |
21 | 21 |
|
@@ -740,288 +740,6 @@ def OT_mapping_linear(xs, xt, reg=1e-6, ws=None, |
740 | 740 | return A, b |
741 | 741 |
|
742 | 742 |
|
743 | | -@deprecated("The class OTDA is deprecated in 0.3.1 and will be " |
744 | | - "removed in 0.5" |
745 | | - "\n\tfor standard transport use class EMDTransport instead.") |
746 | | -class OTDA(object): |
747 | | - |
748 | | - """Class for domain adaptation with optimal transport as proposed in [5] |
749 | | -
|
750 | | -
|
751 | | - References |
752 | | - ---------- |
753 | | -
|
754 | | - .. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, |
755 | | - "Optimal Transport for Domain Adaptation," in IEEE Transactions on |
756 | | - Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1 |
757 | | -
|
758 | | - """ |
759 | | - |
760 | | - def __init__(self, metric='sqeuclidean', norm=None): |
761 | | - """ Class initialization""" |
762 | | - self.xs = 0 |
763 | | - self.xt = 0 |
764 | | - self.G = 0 |
765 | | - self.metric = metric |
766 | | - self.norm = norm |
767 | | - self.computed = False |
768 | | - |
769 | | - def fit(self, xs, xt, ws=None, wt=None, max_iter=100000): |
770 | | - """Fit domain adaptation between samples is xs and xt |
771 | | - (with optional weights)""" |
772 | | - self.xs = xs |
773 | | - self.xt = xt |
774 | | - |
775 | | - if wt is None: |
776 | | - wt = unif(xt.shape[0]) |
777 | | - if ws is None: |
778 | | - ws = unif(xs.shape[0]) |
779 | | - |
780 | | - self.ws = ws |
781 | | - self.wt = wt |
782 | | - |
783 | | - self.M = dist(xs, xt, metric=self.metric) |
784 | | - self.M = cost_normalization(self.M, self.norm) |
785 | | - self.G = emd(ws, wt, self.M, max_iter) |
786 | | - self.computed = True |
787 | | - |
788 | | - def interp(self, direction=1): |
789 | | - """Barycentric interpolation for the source (1) or target (-1) samples |
790 | | -
|
791 | | - This Barycentric interpolation solves for each source (resp target) |
792 | | - sample xs (resp xt) the following optimization problem: |
793 | | -
|
794 | | - .. math:: |
795 | | - arg\min_x \sum_i \gamma_{k,i} c(x,x_i^t) |
796 | | -
|
797 | | - where k is the index of the sample in xs |
798 | | -
|
799 | | - For the moment only squared euclidean distance is provided but more |
800 | | - metric could be used in the future. |
801 | | -
|
802 | | - """ |
803 | | - if direction > 0: # >0 then source to target |
804 | | - G = self.G |
805 | | - w = self.ws.reshape((self.xs.shape[0], 1)) |
806 | | - x = self.xt |
807 | | - else: |
808 | | - G = self.G.T |
809 | | - w = self.wt.reshape((self.xt.shape[0], 1)) |
810 | | - x = self.xs |
811 | | - |
812 | | - if self.computed: |
813 | | - if self.metric == 'sqeuclidean': |
814 | | - return np.dot(G / w, x) # weighted mean |
815 | | - else: |
816 | | - print( |
817 | | - "Warning, metric not handled yet, using weighted average") |
818 | | - return np.dot(G / w, x) # weighted mean |
819 | | - return None |
820 | | - else: |
821 | | - print("Warning, model not fitted yet, returning None") |
822 | | - return None |
823 | | - |
824 | | - def predict(self, x, direction=1): |
825 | | - """ Out of sample mapping using the formulation from [6] |
826 | | -
|
827 | | - For each sample x to map, it finds the nearest source sample xs and |
828 | | - map the samle x to the position xst+(x-xs) wher xst is the barycentric |
829 | | - interpolation of source sample xs. |
830 | | -
|
831 | | - References |
832 | | - ---------- |
833 | | -
|
834 | | - .. [6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014). |
835 | | - Regularized discrete optimal transport. SIAM Journal on Imaging |
836 | | - Sciences, 7(3), 1853-1882. |
837 | | -
|
838 | | - """ |
839 | | - if direction > 0: # >0 then source to target |
840 | | - xf = self.xt |
841 | | - x0 = self.xs |
842 | | - else: |
843 | | - xf = self.xs |
844 | | - x0 = self.xt |
845 | | - |
846 | | - D0 = dist(x, x0) # dist netween new samples an source |
847 | | - idx = np.argmin(D0, 1) # closest one |
848 | | - xf = self.interp(direction) # interp the source samples |
849 | | - # aply the delta to the interpolation |
850 | | - return xf[idx, :] + x - x0[idx, :] |
851 | | - |
852 | | - |
853 | | -@deprecated("The class OTDA_sinkhorn is deprecated in 0.3.1 and will be" |
854 | | - " removed in 0.5 \nUse class SinkhornTransport instead.") |
855 | | -class OTDA_sinkhorn(OTDA): |
856 | | - |
857 | | - """Class for domain adaptation with optimal transport with entropic |
858 | | - regularization |
859 | | -
|
860 | | -
|
861 | | - """ |
862 | | - |
863 | | - def fit(self, xs, xt, reg=1, ws=None, wt=None, **kwargs): |
864 | | - """Fit regularized domain adaptation between samples is xs and xt |
865 | | - (with optional weights)""" |
866 | | - self.xs = xs |
867 | | - self.xt = xt |
868 | | - |
869 | | - if wt is None: |
870 | | - wt = unif(xt.shape[0]) |
871 | | - if ws is None: |
872 | | - ws = unif(xs.shape[0]) |
873 | | - |
874 | | - self.ws = ws |
875 | | - self.wt = wt |
876 | | - |
877 | | - self.M = dist(xs, xt, metric=self.metric) |
878 | | - self.M = cost_normalization(self.M, self.norm) |
879 | | - self.G = sinkhorn(ws, wt, self.M, reg, **kwargs) |
880 | | - self.computed = True |
881 | | - |
882 | | - |
883 | | -@deprecated("The class OTDA_lpl1 is deprecated in 0.3.1 and will be" |
884 | | - " removed in 0.5 \nUse class SinkhornLpl1Transport instead.") |
885 | | -class OTDA_lpl1(OTDA): |
886 | | - |
887 | | - """Class for domain adaptation with optimal transport with entropic and |
888 | | - group regularization""" |
889 | | - |
890 | | - def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, **kwargs): |
891 | | - """Fit regularized domain adaptation between samples is xs and xt |
892 | | - (with optional weights), See ot.da.sinkhorn_lpl1_mm for fit |
893 | | - parameters""" |
894 | | - self.xs = xs |
895 | | - self.xt = xt |
896 | | - |
897 | | - if wt is None: |
898 | | - wt = unif(xt.shape[0]) |
899 | | - if ws is None: |
900 | | - ws = unif(xs.shape[0]) |
901 | | - |
902 | | - self.ws = ws |
903 | | - self.wt = wt |
904 | | - |
905 | | - self.M = dist(xs, xt, metric=self.metric) |
906 | | - self.M = cost_normalization(self.M, self.norm) |
907 | | - self.G = sinkhorn_lpl1_mm(ws, ys, wt, self.M, reg, eta, **kwargs) |
908 | | - self.computed = True |
909 | | - |
910 | | - |
911 | | -@deprecated("The class OTDA_l1L2 is deprecated in 0.3.1 and will be" |
912 | | - " removed in 0.5 \nUse class SinkhornL1l2Transport instead.") |
913 | | -class OTDA_l1l2(OTDA): |
914 | | - |
915 | | - """Class for domain adaptation with optimal transport with entropic |
916 | | - and group lasso regularization""" |
917 | | - |
918 | | - def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, **kwargs): |
919 | | - """Fit regularized domain adaptation between samples is xs and xt |
920 | | - (with optional weights), See ot.da.sinkhorn_lpl1_gl for fit |
921 | | - parameters""" |
922 | | - self.xs = xs |
923 | | - self.xt = xt |
924 | | - |
925 | | - if wt is None: |
926 | | - wt = unif(xt.shape[0]) |
927 | | - if ws is None: |
928 | | - ws = unif(xs.shape[0]) |
929 | | - |
930 | | - self.ws = ws |
931 | | - self.wt = wt |
932 | | - |
933 | | - self.M = dist(xs, xt, metric=self.metric) |
934 | | - self.M = cost_normalization(self.M, self.norm) |
935 | | - self.G = sinkhorn_l1l2_gl(ws, ys, wt, self.M, reg, eta, **kwargs) |
936 | | - self.computed = True |
937 | | - |
938 | | - |
939 | | -@deprecated("The class OTDA_mapping_linear is deprecated in 0.3.1 and will be" |
940 | | - " removed in 0.5 \nUse class MappingTransport instead.") |
941 | | -class OTDA_mapping_linear(OTDA): |
942 | | - |
943 | | - """Class for optimal transport with joint linear mapping estimation as in |
944 | | - [8] |
945 | | - """ |
946 | | - |
947 | | - def __init__(self): |
948 | | - """ Class initialization""" |
949 | | - |
950 | | - self.xs = 0 |
951 | | - self.xt = 0 |
952 | | - self.G = 0 |
953 | | - self.L = 0 |
954 | | - self.bias = False |
955 | | - self.computed = False |
956 | | - self.metric = 'sqeuclidean' |
957 | | - |
958 | | - def fit(self, xs, xt, mu=1, eta=1, bias=False, **kwargs): |
959 | | - """ Fit domain adaptation between samples is xs and xt (with optional |
960 | | - weights)""" |
961 | | - self.xs = xs |
962 | | - self.xt = xt |
963 | | - self.bias = bias |
964 | | - |
965 | | - self.ws = unif(xs.shape[0]) |
966 | | - self.wt = unif(xt.shape[0]) |
967 | | - |
968 | | - self.G, self.L = joint_OT_mapping_linear( |
969 | | - xs, xt, mu=mu, eta=eta, bias=bias, **kwargs) |
970 | | - self.computed = True |
971 | | - |
972 | | - def mapping(self): |
973 | | - return lambda x: self.predict(x) |
974 | | - |
975 | | - def predict(self, x): |
976 | | - """ Out of sample mapping estimated during the call to fit""" |
977 | | - if self.computed: |
978 | | - if self.bias: |
979 | | - x = np.hstack((x, np.ones((x.shape[0], 1)))) |
980 | | - return x.dot(self.L) # aply the delta to the interpolation |
981 | | - else: |
982 | | - print("Warning, model not fitted yet, returning None") |
983 | | - return None |
984 | | - |
985 | | - |
986 | | -@deprecated("The class OTDA_mapping_kernel is deprecated in 0.3.1 and will be" |
987 | | - " removed in 0.5 \nUse class MappingTransport instead.") |
988 | | -class OTDA_mapping_kernel(OTDA_mapping_linear): |
989 | | - |
990 | | - """Class for optimal transport with joint nonlinear mapping |
991 | | - estimation as in [8]""" |
992 | | - |
993 | | - def fit(self, xs, xt, mu=1, eta=1, bias=False, kerneltype='gaussian', |
994 | | - sigma=1, **kwargs): |
995 | | - """ Fit domain adaptation between samples is xs and xt """ |
996 | | - self.xs = xs |
997 | | - self.xt = xt |
998 | | - self.bias = bias |
999 | | - |
1000 | | - self.ws = unif(xs.shape[0]) |
1001 | | - self.wt = unif(xt.shape[0]) |
1002 | | - self.kernel = kerneltype |
1003 | | - self.sigma = sigma |
1004 | | - self.kwargs = kwargs |
1005 | | - |
1006 | | - self.G, self.L = joint_OT_mapping_kernel( |
1007 | | - xs, xt, mu=mu, eta=eta, bias=bias, **kwargs) |
1008 | | - self.computed = True |
1009 | | - |
1010 | | - def predict(self, x): |
1011 | | - """ Out of sample mapping estimated during the call to fit""" |
1012 | | - |
1013 | | - if self.computed: |
1014 | | - K = kernel( |
1015 | | - x, self.xs, method=self.kernel, sigma=self.sigma, |
1016 | | - **self.kwargs) |
1017 | | - if self.bias: |
1018 | | - K = np.hstack((K, np.ones((x.shape[0], 1)))) |
1019 | | - return K.dot(self.L) |
1020 | | - else: |
1021 | | - print("Warning, model not fitted yet, returning None") |
1022 | | - return None |
1023 | | - |
1024 | | - |
1025 | 743 | def distribution_estimation_uniform(X): |
1026 | 744 | """estimates a uniform distribution from an array of samples X |
1027 | 745 |
|
|
0 commit comments