Skip to content

Commit f4a67f3

Browse files
addresses test review comments
1 parent fb5e1bd commit f4a67f3

File tree

1 file changed

+19
-3
lines changed

1 file changed

+19
-3
lines changed

tests/models/jais2/test_modeling_jais2.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,16 @@
1+
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
114
"""Testing suite for the PyTorch Jais2 model."""
215

316
import gc
@@ -232,9 +245,12 @@ def test_compile_static_cache(self):
232245
)
233246

234247
generated_text = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
235-
print(f"Static cache generated text: {generated_text}")
236-
237-
self.assertGreater(generated_ids.shape[1], input_ids.shape[1])
248+
249+
# Verify exact token count (deterministic)
250+
self.assertEqual(generated_ids.shape[1], input_ids.shape[1] + 10)
251+
# Verify generation produced reasonable output
252+
self.assertGreater(len(generated_text), len(prompt))
253+
self.assertTrue(generated_text.startswith(prompt))
238254

239255
del model
240256
backend_empty_cache(torch_device)

0 commit comments

Comments
 (0)