|
| 1 | +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
1 | 14 | """Testing suite for the PyTorch Jais2 model.""" |
2 | 15 |
|
3 | 16 | import gc |
@@ -232,9 +245,12 @@ def test_compile_static_cache(self): |
232 | 245 | ) |
233 | 246 |
|
234 | 247 | generated_text = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True) |
235 | | - print(f"Static cache generated text: {generated_text}") |
236 | | - |
237 | | - self.assertGreater(generated_ids.shape[1], input_ids.shape[1]) |
| 248 | + |
| 249 | + # Verify exact token count (deterministic) |
| 250 | + self.assertEqual(generated_ids.shape[1], input_ids.shape[1] + 10) |
| 251 | + # Verify generation produced reasonable output |
| 252 | + self.assertGreater(len(generated_text), len(prompt)) |
| 253 | + self.assertTrue(generated_text.startswith(prompt)) |
238 | 254 |
|
239 | 255 | del model |
240 | 256 | backend_empty_cache(torch_device) |
|
0 commit comments