@@ -188,6 +188,7 @@ func TestMonitor_EventBridgeSuccess(t *testing.T) {
188188 h .Equals (t , result .NodeName , dnsNodeName )
189189 h .Assert (t , result .PostDrainTask != nil , "PostDrainTask should have been set" )
190190 h .Assert (t , result .PreDrainTask != nil , "PreDrainTask should have been set" )
191+ if event .ID == asgLifecycleEvent .ID { h .Assert (t , result .CancelDrainTask != nil , "CancelDrainTask should have been set" ) }
191192 err = result .PostDrainTask (result , node.Node {})
192193 h .Ok (t , err )
193194 default :
@@ -273,6 +274,7 @@ func TestMonitor_AsgDirectToSqsSuccess(t *testing.T) {
273274 h .Equals (t , result .NodeName , dnsNodeName )
274275 h .Assert (t , result .PostDrainTask != nil , "PostDrainTask should have been set" )
275276 h .Assert (t , result .PreDrainTask != nil , "PreDrainTask should have been set" )
277+ h .Assert (t , result .CancelDrainTask != nil , "CancelDrainTask should have been set" )
276278 err = result .PostDrainTask (result , node.Node {})
277279 h .Ok (t , err )
278280 default :
@@ -365,6 +367,7 @@ func TestMonitor_DrainTasks(t *testing.T) {
365367 h .Equals (st , result .NodeName , dnsNodeName )
366368 h .Assert (st , result .PostDrainTask != nil , "PostDrainTask should have been set" )
367369 h .Assert (st , result .PreDrainTask != nil , "PreDrainTask should have been set" )
370+ if event .ID == asgLifecycleEvent .ID { h .Assert (t , result .CancelDrainTask != nil , "CancelDrainTask should have been set" ) }
368371 err := result .PostDrainTask (result , node.Node {})
369372 h .Ok (st , err )
370373 })
@@ -466,6 +469,7 @@ func TestMonitor_DrainTasks_Errors(t *testing.T) {
466469 h .Equals (t , result .NodeName , dnsNodeName )
467470 h .Assert (t , result .PostDrainTask != nil , "PostDrainTask should have been set" )
468471 h .Assert (t , result .PreDrainTask != nil , "PreDrainTask should have been set" )
472+ if i == 1 { h .Assert (t , result .CancelDrainTask != nil , "CancelDrainTask should have been set" ) }
469473 err := result .PostDrainTask (result , node.Node {})
470474 h .Ok (t , err )
471475 default :
@@ -909,32 +913,39 @@ func TestMonitor_InstanceNotManaged(t *testing.T) {
909913}
910914
911915func TestSendHeartbeats_EarlyClosure (t * testing.T ) {
912- err := heartbeatTestHelper (nil , 3500 , 1 , 5 )
916+ err := heartbeatTestHelper (nil , 3500 , 1 , 5 , false )
913917 h .Ok (t , err )
914918 h .Assert (t , h .HeartbeatCallCount == 3 , "3 Heartbeat Expected, got %d" , h .HeartbeatCallCount )
915919}
916920
917921func TestSendHeartbeats_HeartbeatUntilExpire (t * testing.T ) {
918- err := heartbeatTestHelper (nil , 8000 , 1 , 5 )
922+ err := heartbeatTestHelper (nil , 8000 , 1 , 5 , false )
919923 h .Ok (t , err )
920924 h .Assert (t , h .HeartbeatCallCount == 5 , "5 Heartbeat Expected, got %d" , h .HeartbeatCallCount )
921925}
922926
923927func TestSendHeartbeats_ErrThrottlingASG (t * testing.T ) {
924928 RecordLifecycleActionHeartbeatErr := awserr .New ("Throttling" , "Rate exceeded" , nil )
925- err := heartbeatTestHelper (RecordLifecycleActionHeartbeatErr , 8000 , 1 , 6 )
929+ err := heartbeatTestHelper (RecordLifecycleActionHeartbeatErr , 8000 , 1 , 6 , false )
926930 h .Ok (t , err )
927931 h .Assert (t , h .HeartbeatCallCount == 6 , "6 Heartbeat Expected, got %d" , h .HeartbeatCallCount )
928932}
929933
930934func TestSendHeartbeats_ErrInvalidTarget (t * testing.T ) {
931935 RecordLifecycleActionHeartbeatErr := awserr .New ("ValidationError" , "No active Lifecycle Action found" , nil )
932- err := heartbeatTestHelper (RecordLifecycleActionHeartbeatErr , 6000 , 1 , 4 )
936+ err := heartbeatTestHelper (RecordLifecycleActionHeartbeatErr , 6000 , 1 , 4 , false )
933937 h .Ok (t , err )
934938 h .Assert (t , h .HeartbeatCallCount == 1 , "1 Heartbeat Expected, got %d" , h .HeartbeatCallCount )
935939}
936940
937- func heartbeatTestHelper (RecordLifecycleActionHeartbeatErr error , sleepMilliSeconds int , heartbeatInterval int , heartbeatUntil int ) error {
941+
942+ func TestSendHeartbeats_CancelHeartbeat (t * testing.T ) {
943+ err := heartbeatTestHelper (nil , 6000 , 1 , 4 , true )
944+ h .Ok (t , err )
945+ h .Assert (t , h .HeartbeatCallCount == 2 , "2 Heartbeat Expected, got %d" , h .HeartbeatCallCount )
946+ }
947+
948+ func heartbeatTestHelper (RecordLifecycleActionHeartbeatErr error , sleepMilliSeconds int , heartbeatInterval int , heartbeatUntil int , cancelDrain bool ) error {
938949 h .HeartbeatCallCount = 0
939950
940951 msg , err := getSQSMessageFromEvent (asgLifecycleEvent )
@@ -986,6 +997,16 @@ func heartbeatTestHelper(RecordLifecycleActionHeartbeatErr error, sleepMilliSeco
986997 return err
987998 }
988999
1000+ if cancelDrain == true {
1001+ if result .CancelDrainTask == nil {
1002+ return fmt .Errorf ("CancelDrainTask should have been set" )
1003+ }
1004+ time .Sleep (2100 * time .Millisecond )
1005+ if err := result .CancelDrainTask (result , * testNode ); err != nil {
1006+ return err
1007+ }
1008+ }
1009+
9891010 if result .PostDrainTask == nil {
9901011 return fmt .Errorf ("PostDrainTask should have been set" )
9911012 }
0 commit comments