diff options
-rw-r--r-- | transaction.go | 9 | ||||
-rw-r--r-- | transaction_test.go | 32 |
2 files changed, 22 insertions, 19 deletions
diff --git a/transaction.go b/transaction.go index 7b19f5c..5957071 100644 --- a/transaction.go +++ b/transaction.go @@ -22,7 +22,6 @@ package pam import "C" import ( - "errors" "fmt" "runtime" "runtime/cgo" @@ -146,8 +145,8 @@ func transactionFinalizer(t *Transaction) { // Allows to call pam functions managing return status func (t *Transaction) handlePamStatus(cStatus C.int) error { t.lastStatus.Store(int32(cStatus)) - if cStatus != success { - return t + if status := Error(cStatus); status != success { + return status } return nil } @@ -213,7 +212,7 @@ func start(service, user string, handler ConversationHandler, confDir string) (* err = t.handlePamStatus(C.pam_start_confdir(s, u, t.conv, c, &t.handle)) } if err != nil { - return nil, errors.Join(Error(t.lastStatus.Load()), err) + return nil, err } return t, nil } @@ -365,7 +364,7 @@ func (t *Transaction) GetEnvList() (map[string]string, error) { p := C.pam_getenvlist(t.handle) if p == nil { t.lastStatus.Store(int32(ErrBuf)) - return nil, t + return nil, ErrBuf } t.lastStatus.Store(success) for q := p; *q != nil; q = next(q) { diff --git a/transaction_test.go b/transaction_test.go index 5bc858e..62b4ce1 100644 --- a/transaction_test.go +++ b/transaction_test.go @@ -167,8 +167,8 @@ func TestPAM_007(t *testing.T) { if len(s) == 0 { t.Fatalf("error #expected an error message") } - if tx.Error() != ErrAuth.Error() { - t.Fatalf("error #unexpected status %v", tx.Error()) + if !errors.Is(err, ErrAuth) { + t.Fatalf("error #unexpected error %v", err) } } @@ -255,8 +255,8 @@ func TestPAM_ConfDir_Deny(t *testing.T) { if len(s) == 0 { t.Fatalf("error #expected an error message") } - if tx.Error() != ErrAuth.Error() { - t.Fatalf("error #unexpected status %v", tx.Error()) + if !errors.Is(err, ErrAuth) { + t.Fatalf("error #unexpected error %v", err) } } @@ -304,8 +304,8 @@ func TestPAM_ConfDir_WrongUserName(t *testing.T) { if len(s) == 0 { t.Fatalf("error #expected an error message") } - if tx.Error() != ErrAuth.Error() { - t.Fatalf("error #unexpected status %v", tx.Error()) + if !errors.Is(err, ErrAuth) { + t.Fatalf("error #unexpected error %v", err) } } @@ -416,7 +416,7 @@ func Test_Error(t *testing.T) { } statuses := map[string]error{ - "success": Error(success), + "success": nil, "open_err": ErrOpen, "symbol_err": ErrSymbol, "service_err": ErrService, @@ -441,7 +441,7 @@ func Test_Error(t *testing.T) { "authtok_lock_busy": ErrAuthtokLockBusy, "authtok_disable_aging": ErrAuthtokDisableAging, "try_again": ErrTryAgain, - "ignore": Error(success), /* Ignore can't be returned */ + "ignore": nil, /* Ignore can't be returned */ "abort": ErrAbort, "authtok_expired": ErrAuthtokExpired, "module_unknown": ErrModuleUnknown, @@ -504,13 +504,17 @@ func Test_Error(t *testing.T) { err = tx.OpenSession(0) } - if tx.Error() != expected.Error() { - t.Fatalf("error #unexpected status %v", tx.Error()) + if !errors.Is(err, expected) { + t.Fatalf("error #unexpected status %#v vs %#v", err, + expected) } - if tx.Error() == Error(success).Error() && err != nil { - t.Fatalf("error #unexpected: %v", err) - } else if tx.Error() != Error(success).Error() && err == nil { - t.Fatalf("error #expected an error message") + + if err != nil { + var status Error + if !errors.As(err, &status) || err.Error() != status.Error() { + t.Fatalf("error #unexpected status %v vs %v", err.Error(), + status.Error()) + } } }) } |