Skip to content

Commit e502183

Browse files
committed
Add labels to examples due to zoo deprecation, update usage to reflect new behavior with getRow
1 parent 0434362 commit e502183

File tree

8 files changed

+620
-1
lines changed

8 files changed

+620
-1
lines changed

tensorflow-keras-import-examples/src/main/java/org/deeplearning4j/modelimportexamples/tf/advanced/mobilenet/ImportMobileNetExample.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import org.apache.commons.io.FilenameUtils;
44
import org.datavec.image.loader.ImageLoader;
5-
import org.deeplearning4j.zoo.util.imagenet.ImageNetLabels;
5+
import org.deeplearning4j.modelimportexamples.util.imagenet.ImageNetLabels;
66
import org.nd4j.autodiff.samediff.SameDiff;
77
import org.nd4j.common.resources.Downloader;
88
import org.nd4j.linalg.api.ndarray.INDArray;
Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
/*
2+
* ******************************************************************************
3+
* *
4+
* *
5+
* * This program and the accompanying materials are made available under the
6+
* * terms of the Apache License, Version 2.0 which is available at
7+
* * https://www.apache.org/licenses/LICENSE-2.0.
8+
* *
9+
* * See the NOTICE file distributed with this work for additional
10+
* * information regarding copyright ownership.
11+
* * Unless required by applicable law or agreed to in writing, software
12+
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
13+
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
14+
* * License for the specific language governing permissions and limitations
15+
* * under the License.
16+
* *
17+
* * SPDX-License-Identifier: Apache-2.0
18+
* *****************************************************************************
19+
*/
20+
21+
package org.deeplearning4j.modelimportexamples.util;
22+
23+
import org.deeplearning4j.common.resources.DL4JResources;
24+
import org.deeplearning4j.common.resources.ResourceType;
25+
import org.deeplearning4j.zoo.util.ClassPrediction;
26+
import org.deeplearning4j.zoo.util.Labels;
27+
import org.nd4j.common.base.Preconditions;
28+
import org.nd4j.common.resources.Downloader;
29+
import org.nd4j.linalg.api.ndarray.INDArray;
30+
import org.nd4j.linalg.factory.Nd4j;
31+
32+
import java.io.*;
33+
import java.net.URL;
34+
import java.util.ArrayList;
35+
import java.util.List;
36+
import java.util.Scanner;
37+
38+
public abstract class BaseLabels implements Labels {
39+
40+
protected ArrayList<String> labels;
41+
42+
/** Override {@link #getLabels()} when using this constructor. */
43+
protected BaseLabels() throws IOException {
44+
this.labels = getLabels();
45+
}
46+
47+
/**
48+
* No need to override anything with this constructor.
49+
*
50+
* @param textResource name of a resource containing labels as a list in a text file.
51+
* @throws IOException
52+
*/
53+
protected BaseLabels(String textResource) throws IOException {
54+
this.labels = getLabels(textResource);
55+
}
56+
57+
/**
58+
* Override to return labels when not calling {@link #BaseLabels(String)}.
59+
*/
60+
protected ArrayList<String> getLabels() throws IOException {
61+
return null;
62+
}
63+
64+
/**
65+
* Returns labels based on the text file resource.
66+
*/
67+
protected ArrayList<String> getLabels(String textResource) throws IOException {
68+
ArrayList<String> labels = new ArrayList<>();
69+
File resourceFile = getResourceFile(); //Download if required
70+
try (InputStream is = new BufferedInputStream(new FileInputStream(resourceFile)); Scanner s = new Scanner(is)) {
71+
while (s.hasNextLine()) {
72+
labels.add(s.nextLine());
73+
}
74+
}
75+
return labels;
76+
}
77+
78+
@Override
79+
public String getLabel(int n) {
80+
Preconditions.checkArgument(n >= 0 && n < labels.size(), "Invalid index: %s. Must be in range" +
81+
"0 <= n < %s", n, labels.size());
82+
return labels.get(n);
83+
}
84+
85+
@Override
86+
public List<List<ClassPrediction>> decodePredictions(INDArray predictions, int n) {
87+
if(predictions.rank() == 1){
88+
//Reshape 1d edge case to [1, nClasses] 2d
89+
predictions = predictions.reshape(1, predictions.length());
90+
}
91+
Preconditions.checkState(predictions.size(1) == labels.size(), "Invalid input array:" +
92+
" expected array with size(1) equal to numLabels (%s), got array with shape %s", labels.size(), predictions.shape());
93+
94+
long rows = predictions.size(0);
95+
long cols = predictions.size(1);
96+
if (predictions.isColumnVectorOrScalar()) {
97+
predictions = predictions.ravel();
98+
rows = (int) predictions.size(0);
99+
cols = (int) predictions.size(1);
100+
}
101+
List<List<ClassPrediction>> descriptions = new ArrayList<>();
102+
for (int batch = 0; batch < rows; batch++) {
103+
INDArray result = predictions.getRow(batch, true);
104+
result = Nd4j.vstack(Nd4j.linspace(result.dataType(), 0, cols, 1).reshape(1,cols), result);
105+
result = Nd4j.sortColumns(result, 1, false);
106+
List<ClassPrediction> current = new ArrayList<>();
107+
for (int i = 0; i < n; i++) {
108+
int label = result.getInt(0, i);
109+
double prob = result.getDouble(1, i);
110+
current.add(new ClassPrediction(label, getLabel(label), prob));
111+
}
112+
descriptions.add(current);
113+
}
114+
return descriptions;
115+
}
116+
117+
/**
118+
* @return URL of the resource to download
119+
*/
120+
protected abstract URL getURL();
121+
122+
/**
123+
* @return Name of the resource (used for inferring local storage parent directory)
124+
*/
125+
protected abstract String resourceName();
126+
127+
/**
128+
* @return MD5 of the resource at getURL()
129+
*/
130+
protected abstract String resourceMD5();
131+
132+
/**
133+
* Download the resource at getURL() to the local resource directory, and return the local copy as a File
134+
*
135+
* @return File of the local resource
136+
*/
137+
protected File getResourceFile() {
138+
139+
URL url = getURL();
140+
String urlString = url.toString();
141+
String filename = urlString.substring(urlString.lastIndexOf('/')+1);
142+
File resourceDir = DL4JResources.getDirectory(ResourceType.RESOURCE, resourceName());
143+
File localFile = new File(resourceDir, filename);
144+
145+
String expMD5 = resourceMD5();
146+
if(localFile.exists()) {
147+
try{
148+
//empty string means ignore the MD5
149+
if(Downloader.checkMD5OfFile(expMD5, localFile)) {
150+
return localFile;
151+
}
152+
} catch (IOException e){
153+
//Ignore
154+
}
155+
//MD5 failed
156+
localFile.delete();
157+
}
158+
159+
//Download
160+
try {
161+
Downloader.download(resourceName(), url, localFile, expMD5, 3);
162+
} catch (IOException e){
163+
throw new RuntimeException("Error downloading labels",e);
164+
}
165+
166+
return localFile;
167+
}
168+
169+
}
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
/*
2+
* ******************************************************************************
3+
* *
4+
* *
5+
* * This program and the accompanying materials are made available under the
6+
* * terms of the Apache License, Version 2.0 which is available at
7+
* * https://www.apache.org/licenses/LICENSE-2.0.
8+
* *
9+
* * See the NOTICE file distributed with this work for additional
10+
* * information regarding copyright ownership.
11+
* * Unless required by applicable law or agreed to in writing, software
12+
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
13+
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
14+
* * License for the specific language governing permissions and limitations
15+
* * under the License.
16+
* *
17+
* * SPDX-License-Identifier: Apache-2.0
18+
* *****************************************************************************
19+
*/
20+
21+
package org.deeplearning4j.modelimportexamples.util;
22+
23+
import lombok.AllArgsConstructor;
24+
import lombok.Data;
25+
26+
import java.util.Objects;
27+
28+
public class ClassPrediction {
29+
30+
private int number;
31+
private String label;
32+
33+
@Override
34+
public boolean equals(Object o) {
35+
if (this == o) return true;
36+
if (o == null || getClass() != o.getClass()) return false;
37+
ClassPrediction that = (ClassPrediction) o;
38+
return number == that.number && Double.compare(that.probability, probability) == 0 && Objects.equals(label, that.label);
39+
}
40+
41+
@Override
42+
public int hashCode() {
43+
return Objects.hash(number, label, probability);
44+
}
45+
46+
public int getNumber() {
47+
return number;
48+
}
49+
50+
public void setNumber(int number) {
51+
this.number = number;
52+
}
53+
54+
public String getLabel() {
55+
return label;
56+
}
57+
58+
public void setLabel(String label) {
59+
this.label = label;
60+
}
61+
62+
public double getProbability() {
63+
return probability;
64+
}
65+
66+
public void setProbability(double probability) {
67+
this.probability = probability;
68+
}
69+
70+
private double probability;
71+
72+
@Override
73+
public String toString() {
74+
return "ClassPrediction(number=" + number + ",label=" + label + ",probability=" + probability + ")";
75+
}
76+
}
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
/*
2+
* ******************************************************************************
3+
* *
4+
* *
5+
* * This program and the accompanying materials are made available under the
6+
* * terms of the Apache License, Version 2.0 which is available at
7+
* * https://www.apache.org/licenses/LICENSE-2.0.
8+
* *
9+
* * See the NOTICE file distributed with this work for additional
10+
* * information regarding copyright ownership.
11+
* * Unless required by applicable law or agreed to in writing, software
12+
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
13+
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
14+
* * License for the specific language governing permissions and limitations
15+
* * under the License.
16+
* *
17+
* * SPDX-License-Identifier: Apache-2.0
18+
* *****************************************************************************
19+
*/
20+
21+
package org.deeplearning4j.modelimportexamples.util;
22+
23+
import org.deeplearning4j.zoo.util.ClassPrediction;
24+
import org.nd4j.linalg.api.ndarray.INDArray;
25+
26+
import java.util.List;
27+
28+
public interface Labels {
29+
30+
/**
31+
* Returns the description of the nth class from the classes of a dataset.
32+
* @param n
33+
* @return label description
34+
*/
35+
String getLabel(int n);
36+
37+
/**
38+
* Given predictions from the trained model this method will return a list
39+
* of the top n matches and the respective probabilities.
40+
* @param predictions raw
41+
* @return decoded predictions
42+
*/
43+
List<List<ClassPrediction>> decodePredictions(INDArray predictions, int n);
44+
}
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
/*
2+
* ******************************************************************************
3+
* *
4+
* *
5+
* * This program and the accompanying materials are made available under the
6+
* * terms of the Apache License, Version 2.0 which is available at
7+
* * https://www.apache.org/licenses/LICENSE-2.0.
8+
* *
9+
* * See the NOTICE file distributed with this work for additional
10+
* * information regarding copyright ownership.
11+
* * Unless required by applicable law or agreed to in writing, software
12+
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
13+
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
14+
* * License for the specific language governing permissions and limitations
15+
* * under the License.
16+
* *
17+
* * SPDX-License-Identifier: Apache-2.0
18+
* *****************************************************************************
19+
*/
20+
21+
package org.deeplearning4j.modelimportexamples.util.darknet;
22+
23+
import org.deeplearning4j.common.resources.DL4JResources;
24+
import org.deeplearning4j.zoo.util.BaseLabels;
25+
26+
import java.io.IOException;
27+
import java.net.MalformedURLException;
28+
import java.net.URL;
29+
30+
public class COCOLabels extends BaseLabels {
31+
32+
public COCOLabels() throws IOException {
33+
super("coco.names");
34+
}
35+
36+
@Override
37+
protected URL getURL() {
38+
try {
39+
return DL4JResources.getURL("resources/darknet/coco.names");
40+
} catch (MalformedURLException e){
41+
throw new RuntimeException(e);
42+
}
43+
}
44+
45+
@Override
46+
protected String resourceName() {
47+
return "darknet";
48+
}
49+
50+
@Override
51+
protected String resourceMD5() {
52+
return "4caf6834300c8b2ff19964b36e54d637";
53+
}
54+
}

0 commit comments

Comments
 (0)