summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--transaction.go36
-rw-r--r--transaction_test.go49
2 files changed, 75 insertions, 10 deletions
diff --git a/transaction.go b/transaction.go
index 2809614..52ad1be 100644
--- a/transaction.go
+++ b/transaction.go
@@ -23,7 +23,6 @@ import "C"
import (
"fmt"
- "runtime"
"runtime/cgo"
"strings"
"sync/atomic"
@@ -133,11 +132,17 @@ type Transaction struct {
c cgo.Handle
}
-// transactionFinalizer cleans up the PAM handle and deletes the callback
-// function.
-func transactionFinalizer(t *Transaction) {
- C.pam_end(t.handle, C.int(t.lastStatus.Load()))
- t.c.Delete()
+// End cleans up the PAM handle and deletes the callback function.
+// It must be called when done with the transaction.
+func (t *Transaction) End() error {
+ handle := atomic.SwapPointer((*unsafe.Pointer)(unsafe.Pointer(&t.handle)), nil)
+ if handle == nil {
+ return nil
+ }
+
+ defer t.c.Delete()
+ return t.handlePamStatus(C.pam_end((*C.pam_handle_t)(handle),
+ C.int(t.lastStatus.Load())))
}
// Allows to call pam functions managing return status
@@ -154,11 +159,19 @@ func (t *Transaction) handlePamStatus(cStatus C.int) error {
//
// All application calls to PAM begin with Start*. The returned
// transaction provides an interface to the remainder of the API.
+//
+// It's responsibility of the Transaction owner to release all the resources
+// allocated underneath by PAM by calling End() once done.
+//
+// It's not advised to End the transaction using a runtime.SetFinalizer unless
+// you're absolutely sure that your stack is multi-thread friendly (normally it
+// is not!) and using a LockOSThread/UnlockOSThread pair.
func Start(service, user string, handler ConversationHandler) (*Transaction, error) {
return start(service, user, handler, "")
}
-// StartFunc registers the handler func as a conversation handler.
+// StartFunc registers the handler func as a conversation handler and starts
+// the transaction (see Start() documentation).
func StartFunc(service, user string, handler func(Style, string) (string, error)) (*Transaction, error) {
return Start(service, user, ConversationFunc(handler))
}
@@ -170,6 +183,13 @@ func StartFunc(service, user string, handler func(Style, string) (string, error)
//
// All application calls to PAM begin with Start*. The returned
// transaction provides an interface to the remainder of the API.
+//
+// It's responsibility of the Transaction owner to release all the resources
+// allocated underneath by PAM by calling End() once done.
+//
+// It's not advised to End the transaction using a runtime.SetFinalizer unless
+// you're absolutely sure that your stack is multi-thread friendly (normally it
+// is not!) and using a LockOSThread/UnlockOSThread pair.
func StartConfDir(service, user string, handler ConversationHandler, confDir string) (*Transaction, error) {
if !CheckPamHasStartConfdir() {
return nil, fmt.Errorf(
@@ -193,7 +213,6 @@ func start(service, user string, handler ConversationHandler, confDir string) (*
c: cgo.NewHandle(handler),
}
C.init_pam_conv(t.conv, C.uintptr_t(t.c))
- runtime.SetFinalizer(t, transactionFinalizer)
s := C.CString(service)
defer C.free(unsafe.Pointer(s))
var u *C.char
@@ -210,6 +229,7 @@ func start(service, user string, handler ConversationHandler, confDir string) (*
err = t.handlePamStatus(C.pam_start_confdir(s, u, t.conv, c, &t.handle))
}
if err != nil {
+ var _ = t.End()
return nil, err
}
return t, nil
diff --git a/transaction_test.go b/transaction_test.go
index 0e3a481..2f620a2 100644
--- a/transaction_test.go
+++ b/transaction_test.go
@@ -9,6 +9,18 @@ import (
"testing"
)
+func maybeEndTransaction(t *testing.T, tx *Transaction) {
+ t.Helper()
+
+ if tx == nil {
+ return
+ }
+ err := tx.End()
+ if err != nil {
+ t.Fatalf("end #error: %v", err)
+ }
+}
+
func TestPAM_001(t *testing.T) {
u, _ := user.Current()
if u.Uid != "0" {
@@ -18,6 +30,7 @@ func TestPAM_001(t *testing.T) {
tx, err := StartFunc("", "test", func(s Style, msg string) (string, error) {
return p, nil
})
+ defer maybeEndTransaction(t, tx)
if err != nil {
t.Fatalf("start #error: %v", err)
}
@@ -49,6 +62,7 @@ func TestPAM_002(t *testing.T) {
}
return "", errors.New("unexpected")
})
+ defer maybeEndTransaction(t, tx)
if err != nil {
t.Fatalf("start #error: %v", err)
}
@@ -83,6 +97,7 @@ func TestPAM_003(t *testing.T) {
Password: "secret",
}
tx, err := Start("", "", c)
+ defer maybeEndTransaction(t, tx)
if err != nil {
t.Fatalf("start #error: %v", err)
}
@@ -101,6 +116,7 @@ func TestPAM_004(t *testing.T) {
Password: "secret",
}
tx, err := Start("", "test", c)
+ defer maybeEndTransaction(t, tx)
if err != nil {
t.Fatalf("start #error: %v", err)
}
@@ -118,6 +134,7 @@ func TestPAM_005(t *testing.T) {
tx, err := StartFunc("passwd", "test", func(s Style, msg string) (string, error) {
return "secret", nil
})
+ defer maybeEndTransaction(t, tx)
if err != nil {
t.Fatalf("start #error: %v", err)
}
@@ -135,6 +152,7 @@ func TestPAM_006(t *testing.T) {
tx, err := StartFunc("passwd", u.Username, func(s Style, msg string) (string, error) {
return "secret", nil
})
+ defer maybeEndTransaction(t, tx)
if err != nil {
t.Fatalf("start #error: %v", err)
}
@@ -156,6 +174,7 @@ func TestPAM_007(t *testing.T) {
tx, err := StartFunc("", "test", func(s Style, msg string) (string, error) {
return "", errors.New("Sorry, it didn't work")
})
+ defer maybeEndTransaction(t, tx)
if err != nil {
t.Fatalf("start #error: %v", err)
}
@@ -179,6 +198,11 @@ func TestPAM_ConfDir(t *testing.T) {
Password: "wrongsecret",
}
tx, err := StartConfDir("permit-service", u.Username, c, "test-services")
+ defer func() {
+ if tx != nil {
+ _ = tx.End()
+ }
+ }()
if !CheckPamHasStartConfdir() {
if err == nil {
t.Fatalf("start should have errored out as pam_start_confdir is not available: %v", err)
@@ -200,10 +224,13 @@ func TestPAM_ConfDir_FailNoServiceOrUnsupported(t *testing.T) {
c := Credentials{
Password: "secret",
}
- _, err := StartConfDir("does-not-exists", u.Username, c, ".")
+ tx, err := StartConfDir("does-not-exists", u.Username, c, ".")
if err == nil {
t.Fatalf("authenticate #expected an error")
}
+ if tx != nil {
+ t.Fatalf("authenticate #unexpected transaction")
+ }
s := err.Error()
if len(s) == 0 {
t.Fatalf("error #expected an error message")
@@ -229,6 +256,7 @@ func TestPAM_ConfDir_InfoMessage(t *testing.T) {
}
return "", errors.New("unexpected")
}), "test-services")
+ defer maybeEndTransaction(t, tx)
if err != nil {
t.Fatalf("start #error: %v", err)
}
@@ -244,6 +272,7 @@ func TestPAM_ConfDir_InfoMessage(t *testing.T) {
func TestPAM_ConfDir_Deny(t *testing.T) {
u, _ := user.Current()
tx, err := StartConfDir("deny-service", u.Username, Credentials{}, "test-services")
+ defer maybeEndTransaction(t, tx)
if err != nil {
t.Fatalf("start #error: %v", err)
}
@@ -267,6 +296,7 @@ func TestPAM_ConfDir_PromptForUserName(t *testing.T) {
Password: "wrongsecret",
}
tx, err := StartConfDir("succeed-if-user-test", "", c, "test-services")
+ defer maybeEndTransaction(t, tx)
if !CheckPamHasStartConfdir() {
if err == nil {
t.Fatalf("start should have errored out as pam_start_confdir is not available: %v", err)
@@ -289,6 +319,7 @@ func TestPAM_ConfDir_WrongUserName(t *testing.T) {
Password: "wrongsecret",
}
tx, err := StartConfDir("succeed-if-user-test", "", c, "test-services")
+ defer maybeEndTransaction(t, tx)
if !CheckPamHasStartConfdir() {
if err == nil {
t.Fatalf("start should have errored out as pam_start_confdir is not available: %v", err)
@@ -310,9 +341,13 @@ func TestPAM_ConfDir_WrongUserName(t *testing.T) {
}
func TestItem(t *testing.T) {
- tx, _ := StartFunc("passwd", "test", func(s Style, msg string) (string, error) {
+ tx, err := StartFunc("passwd", "test", func(s Style, msg string) (string, error) {
return "", nil
})
+ defer maybeEndTransaction(t, tx)
+ if err != nil {
+ t.Fatalf("start #error: %v", err)
+ }
s, err := tx.GetItem(Service)
if err != nil {
@@ -347,6 +382,7 @@ func TestEnv(t *testing.T) {
tx, err := StartFunc("", "", func(s Style, msg string) (string, error) {
return "", nil
})
+ defer maybeEndTransaction(t, tx)
if err != nil {
t.Fatalf("start #error: %v", err)
}
@@ -489,6 +525,7 @@ func Test_Error(t *testing.T) {
}
tx, err := StartConfDir(serviceName, "user", c, servicePath)
+ defer maybeEndTransaction(t, tx)
if err != nil {
t.Fatalf("start #error: %v", err)
}
@@ -592,3 +629,11 @@ func TestFailure_009(t *testing.T) {
t.Fatalf("getenvlist #expected an error")
}
}
+
+func TestFailure_010(t *testing.T) {
+ tx := Transaction{}
+ err := tx.End()
+ if err != nil {
+ t.Fatalf("end #unexpected error %v", err)
+ }
+}