diff options
-rw-r--r-- | transaction.go | 36 | ||||
-rw-r--r-- | transaction_test.go | 49 |
2 files changed, 75 insertions, 10 deletions
diff --git a/transaction.go b/transaction.go index 2809614..52ad1be 100644 --- a/transaction.go +++ b/transaction.go @@ -23,7 +23,6 @@ import "C" import ( "fmt" - "runtime" "runtime/cgo" "strings" "sync/atomic" @@ -133,11 +132,17 @@ type Transaction struct { c cgo.Handle } -// transactionFinalizer cleans up the PAM handle and deletes the callback -// function. -func transactionFinalizer(t *Transaction) { - C.pam_end(t.handle, C.int(t.lastStatus.Load())) - t.c.Delete() +// End cleans up the PAM handle and deletes the callback function. +// It must be called when done with the transaction. +func (t *Transaction) End() error { + handle := atomic.SwapPointer((*unsafe.Pointer)(unsafe.Pointer(&t.handle)), nil) + if handle == nil { + return nil + } + + defer t.c.Delete() + return t.handlePamStatus(C.pam_end((*C.pam_handle_t)(handle), + C.int(t.lastStatus.Load()))) } // Allows to call pam functions managing return status @@ -154,11 +159,19 @@ func (t *Transaction) handlePamStatus(cStatus C.int) error { // // All application calls to PAM begin with Start*. The returned // transaction provides an interface to the remainder of the API. +// +// It's responsibility of the Transaction owner to release all the resources +// allocated underneath by PAM by calling End() once done. +// +// It's not advised to End the transaction using a runtime.SetFinalizer unless +// you're absolutely sure that your stack is multi-thread friendly (normally it +// is not!) and using a LockOSThread/UnlockOSThread pair. func Start(service, user string, handler ConversationHandler) (*Transaction, error) { return start(service, user, handler, "") } -// StartFunc registers the handler func as a conversation handler. +// StartFunc registers the handler func as a conversation handler and starts +// the transaction (see Start() documentation). func StartFunc(service, user string, handler func(Style, string) (string, error)) (*Transaction, error) { return Start(service, user, ConversationFunc(handler)) } @@ -170,6 +183,13 @@ func StartFunc(service, user string, handler func(Style, string) (string, error) // // All application calls to PAM begin with Start*. The returned // transaction provides an interface to the remainder of the API. +// +// It's responsibility of the Transaction owner to release all the resources +// allocated underneath by PAM by calling End() once done. +// +// It's not advised to End the transaction using a runtime.SetFinalizer unless +// you're absolutely sure that your stack is multi-thread friendly (normally it +// is not!) and using a LockOSThread/UnlockOSThread pair. func StartConfDir(service, user string, handler ConversationHandler, confDir string) (*Transaction, error) { if !CheckPamHasStartConfdir() { return nil, fmt.Errorf( @@ -193,7 +213,6 @@ func start(service, user string, handler ConversationHandler, confDir string) (* c: cgo.NewHandle(handler), } C.init_pam_conv(t.conv, C.uintptr_t(t.c)) - runtime.SetFinalizer(t, transactionFinalizer) s := C.CString(service) defer C.free(unsafe.Pointer(s)) var u *C.char @@ -210,6 +229,7 @@ func start(service, user string, handler ConversationHandler, confDir string) (* err = t.handlePamStatus(C.pam_start_confdir(s, u, t.conv, c, &t.handle)) } if err != nil { + var _ = t.End() return nil, err } return t, nil diff --git a/transaction_test.go b/transaction_test.go index 0e3a481..2f620a2 100644 --- a/transaction_test.go +++ b/transaction_test.go @@ -9,6 +9,18 @@ import ( "testing" ) +func maybeEndTransaction(t *testing.T, tx *Transaction) { + t.Helper() + + if tx == nil { + return + } + err := tx.End() + if err != nil { + t.Fatalf("end #error: %v", err) + } +} + func TestPAM_001(t *testing.T) { u, _ := user.Current() if u.Uid != "0" { @@ -18,6 +30,7 @@ func TestPAM_001(t *testing.T) { tx, err := StartFunc("", "test", func(s Style, msg string) (string, error) { return p, nil }) + defer maybeEndTransaction(t, tx) if err != nil { t.Fatalf("start #error: %v", err) } @@ -49,6 +62,7 @@ func TestPAM_002(t *testing.T) { } return "", errors.New("unexpected") }) + defer maybeEndTransaction(t, tx) if err != nil { t.Fatalf("start #error: %v", err) } @@ -83,6 +97,7 @@ func TestPAM_003(t *testing.T) { Password: "secret", } tx, err := Start("", "", c) + defer maybeEndTransaction(t, tx) if err != nil { t.Fatalf("start #error: %v", err) } @@ -101,6 +116,7 @@ func TestPAM_004(t *testing.T) { Password: "secret", } tx, err := Start("", "test", c) + defer maybeEndTransaction(t, tx) if err != nil { t.Fatalf("start #error: %v", err) } @@ -118,6 +134,7 @@ func TestPAM_005(t *testing.T) { tx, err := StartFunc("passwd", "test", func(s Style, msg string) (string, error) { return "secret", nil }) + defer maybeEndTransaction(t, tx) if err != nil { t.Fatalf("start #error: %v", err) } @@ -135,6 +152,7 @@ func TestPAM_006(t *testing.T) { tx, err := StartFunc("passwd", u.Username, func(s Style, msg string) (string, error) { return "secret", nil }) + defer maybeEndTransaction(t, tx) if err != nil { t.Fatalf("start #error: %v", err) } @@ -156,6 +174,7 @@ func TestPAM_007(t *testing.T) { tx, err := StartFunc("", "test", func(s Style, msg string) (string, error) { return "", errors.New("Sorry, it didn't work") }) + defer maybeEndTransaction(t, tx) if err != nil { t.Fatalf("start #error: %v", err) } @@ -179,6 +198,11 @@ func TestPAM_ConfDir(t *testing.T) { Password: "wrongsecret", } tx, err := StartConfDir("permit-service", u.Username, c, "test-services") + defer func() { + if tx != nil { + _ = tx.End() + } + }() if !CheckPamHasStartConfdir() { if err == nil { t.Fatalf("start should have errored out as pam_start_confdir is not available: %v", err) @@ -200,10 +224,13 @@ func TestPAM_ConfDir_FailNoServiceOrUnsupported(t *testing.T) { c := Credentials{ Password: "secret", } - _, err := StartConfDir("does-not-exists", u.Username, c, ".") + tx, err := StartConfDir("does-not-exists", u.Username, c, ".") if err == nil { t.Fatalf("authenticate #expected an error") } + if tx != nil { + t.Fatalf("authenticate #unexpected transaction") + } s := err.Error() if len(s) == 0 { t.Fatalf("error #expected an error message") @@ -229,6 +256,7 @@ func TestPAM_ConfDir_InfoMessage(t *testing.T) { } return "", errors.New("unexpected") }), "test-services") + defer maybeEndTransaction(t, tx) if err != nil { t.Fatalf("start #error: %v", err) } @@ -244,6 +272,7 @@ func TestPAM_ConfDir_InfoMessage(t *testing.T) { func TestPAM_ConfDir_Deny(t *testing.T) { u, _ := user.Current() tx, err := StartConfDir("deny-service", u.Username, Credentials{}, "test-services") + defer maybeEndTransaction(t, tx) if err != nil { t.Fatalf("start #error: %v", err) } @@ -267,6 +296,7 @@ func TestPAM_ConfDir_PromptForUserName(t *testing.T) { Password: "wrongsecret", } tx, err := StartConfDir("succeed-if-user-test", "", c, "test-services") + defer maybeEndTransaction(t, tx) if !CheckPamHasStartConfdir() { if err == nil { t.Fatalf("start should have errored out as pam_start_confdir is not available: %v", err) @@ -289,6 +319,7 @@ func TestPAM_ConfDir_WrongUserName(t *testing.T) { Password: "wrongsecret", } tx, err := StartConfDir("succeed-if-user-test", "", c, "test-services") + defer maybeEndTransaction(t, tx) if !CheckPamHasStartConfdir() { if err == nil { t.Fatalf("start should have errored out as pam_start_confdir is not available: %v", err) @@ -310,9 +341,13 @@ func TestPAM_ConfDir_WrongUserName(t *testing.T) { } func TestItem(t *testing.T) { - tx, _ := StartFunc("passwd", "test", func(s Style, msg string) (string, error) { + tx, err := StartFunc("passwd", "test", func(s Style, msg string) (string, error) { return "", nil }) + defer maybeEndTransaction(t, tx) + if err != nil { + t.Fatalf("start #error: %v", err) + } s, err := tx.GetItem(Service) if err != nil { @@ -347,6 +382,7 @@ func TestEnv(t *testing.T) { tx, err := StartFunc("", "", func(s Style, msg string) (string, error) { return "", nil }) + defer maybeEndTransaction(t, tx) if err != nil { t.Fatalf("start #error: %v", err) } @@ -489,6 +525,7 @@ func Test_Error(t *testing.T) { } tx, err := StartConfDir(serviceName, "user", c, servicePath) + defer maybeEndTransaction(t, tx) if err != nil { t.Fatalf("start #error: %v", err) } @@ -592,3 +629,11 @@ func TestFailure_009(t *testing.T) { t.Fatalf("getenvlist #expected an error") } } + +func TestFailure_010(t *testing.T) { + tx := Transaction{} + err := tx.End() + if err != nil { + t.Fatalf("end #unexpected error %v", err) + } +} |