77import sys
88import tempfile
99import time
10+ from enum import Enum
1011from pathlib import Path
1112from types import TracebackType
1213from typing import (
@@ -444,14 +445,18 @@ def assert_bash_exec(
444445class bash_env_saved :
445446 counter : int = 0
446447
448+ class saved_state (Enum ):
449+ ChangesDetected = 1
450+ ChangesIgnored = 2
451+
447452 def __init__ (self , bash : pexpect .spawn , sendintr : bool = False ):
448453 bash_env_saved .counter += 1
449454 self .prefix : str = "_comp__test_%d" % bash_env_saved .counter
450455
451456 self .bash = bash
452457 self .cwd_changed : bool = False
453- self .saved_shopt : Dict [str , int ] = {}
454- self .saved_variables : Dict [str , int ] = {}
458+ self .saved_shopt : Dict [str , bash_env_saved . saved_state ] = {}
459+ self .saved_variables : Dict [str , bash_env_saved . saved_state ] = {}
455460 self .sendintr = sendintr
456461
457462 self .noexcept : bool = False
@@ -516,14 +521,19 @@ def _save_cwd(self):
516521 self ._copy_variable ("PWD" , "%s_OLDPWD" % self .prefix )
517522
518523 def _check_shopt (self , name : str ):
524+ if (
525+ self .saved_shopt [name ]
526+ != bash_env_saved .saved_state .ChangesDetected
527+ ):
528+ return
519529 self ._safe_assert (
520530 '[[ $(shopt -p %s) == "${%s_NEWSHOPT_%s}" ]]'
521531 % (name , self .prefix , name ),
522532 )
523533
524534 def _unprotect_shopt (self , name : str ):
525535 if name not in self .saved_shopt :
526- self .saved_shopt [name ] = 1
536+ self .saved_shopt [name ] = bash_env_saved . saved_state . ChangesDetected
527537 self ._safe_exec (
528538 "%s_OLDSHOPT_%s=$(shopt -p %s || true)"
529539 % (self .prefix , name , name ),
@@ -538,6 +548,11 @@ def _protect_shopt(self, name: str):
538548 )
539549
540550 def _check_variable (self , varname : str ):
551+ if (
552+ self .saved_variables [varname ]
553+ != bash_env_saved .saved_state .ChangesDetected
554+ ):
555+ return
541556 try :
542557 self ._safe_assert (
543558 '[[ ${%s-%s} == "${%s_NEWVAR_%s-%s}" ]]'
@@ -556,7 +571,9 @@ def _check_variable(self, varname: str):
556571
557572 def _unprotect_variable (self , varname : str ):
558573 if varname not in self .saved_variables :
559- self .saved_variables [varname ] = 1
574+ self .saved_variables [
575+ varname
576+ ] = bash_env_saved .saved_state .ChangesDetected
560577 self ._copy_variable (
561578 varname , "%s_OLDVAR_%s" % (self .prefix , varname )
562579 )
@@ -581,13 +598,6 @@ def _restore_env(self):
581598 self ._unset_variable ("%s_OLDPWD" % self .prefix )
582599 self .cwd_changed = False
583600
584- for name in self .saved_shopt :
585- self ._check_shopt (name )
586- self ._safe_exec ('eval "$%s_OLDSHOPT_%s"' % (self .prefix , name ))
587- self ._unset_variable ("%s_OLDSHOPT_%s" % (self .prefix , name ))
588- self ._unset_variable ("%s_NEWSHOPT_%s" % (self .prefix , name ))
589- self .saved_shopt = {}
590-
591601 for varname in self .saved_variables :
592602 self ._check_variable (varname )
593603 self ._copy_variable (
@@ -597,6 +607,13 @@ def _restore_env(self):
597607 self ._unset_variable ("%s_NEWVAR_%s" % (self .prefix , varname ))
598608 self .saved_variables = {}
599609
610+ for name in self .saved_shopt :
611+ self ._check_shopt (name )
612+ self ._safe_exec ('eval "$%s_OLDSHOPT_%s"' % (self .prefix , name ))
613+ self ._unset_variable ("%s_OLDSHOPT_%s" % (self .prefix , name ))
614+ self ._unset_variable ("%s_NEWSHOPT_%s" % (self .prefix , name ))
615+ self .saved_shopt = {}
616+
600617 self .noexcept = False
601618 if self .captured_error :
602619 raise self .captured_error
@@ -616,13 +633,23 @@ def shopt(self, name: str, value: bool):
616633 self ._safe_exec ("shopt -u %s" % name )
617634 self ._protect_shopt (name )
618635
636+ def save_shopt (self , name : str ):
637+ self ._unprotect_shopt (name )
638+ self .saved_shopt [name ] = bash_env_saved .saved_state .ChangesIgnored
639+
619640 def write_variable (self , varname : str , new_value : str , quote : bool = True ):
620641 if quote :
621642 new_value = shlex .quote (new_value )
622643 self ._unprotect_variable (varname )
623644 self ._safe_exec ("%s=%s" % (varname , new_value ))
624645 self ._protect_variable (varname )
625646
647+ def save_variable (self , varname : str ):
648+ self ._unprotect_variable (varname )
649+ self .saved_variables [
650+ varname
651+ ] = bash_env_saved .saved_state .ChangesIgnored
652+
626653 # TODO: We may restore the "export" attribute as well though it is
627654 # not currently tested in "diff_env"
628655 def write_env (self , envname : str , new_value : str , quote : bool = True ):
0 commit comments