summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMarco Trevisan (Treviño) <[email protected]>2023-09-29 23:01:50 +0200
committerMarco Trevisan (Treviño) <[email protected]>2023-11-30 01:16:39 +0100
commit911a346a003fd5ef80688caf03a0296b95efee69 (patch)
treeaf96ea2bbb21dccbbe4807196b27eb201779e50c
parent3e4f7f5e4be10027f645e21dd2aae37cd4c580a9 (diff)
transaction: Use Atomic to store/load the status
Transactions save the status of each operation in a status field, however such field could be written concurrently by various operations, so we need to be sure that: - We always return the status for the current operation - We store the status in a atomic way so that other actions won't create write races In general, in a multi-thread operation one should not rely on Transaction.Error() to get info about the last operation.
-rw-r--r--transaction.go21
1 files changed, 12 insertions, 9 deletions
diff --git a/transaction.go b/transaction.go
index 642059f..7b19f5c 100644
--- a/transaction.go
+++ b/transaction.go
@@ -27,6 +27,7 @@ import (
"runtime"
"runtime/cgo"
"strings"
+ "sync/atomic"
"unsafe"
)
@@ -129,22 +130,22 @@ func cbPAMConv(s C.int, msg *C.char, c C.uintptr_t) (*C.char, C.int) {
//
//nolint:errname
type Transaction struct {
- handle *C.pam_handle_t
- conv *C.struct_pam_conv
- status C.int
- c cgo.Handle
+ handle *C.pam_handle_t
+ conv *C.struct_pam_conv
+ lastStatus atomic.Int32
+ c cgo.Handle
}
// transactionFinalizer cleans up the PAM handle and deletes the callback
// function.
func transactionFinalizer(t *Transaction) {
- C.pam_end(t.handle, t.status)
+ C.pam_end(t.handle, C.int(t.lastStatus.Load()))
t.c.Delete()
}
// Allows to call pam functions managing return status
func (t *Transaction) handlePamStatus(cStatus C.int) error {
- t.status = cStatus
+ t.lastStatus.Store(int32(cStatus))
if cStatus != success {
return t
}
@@ -212,13 +213,13 @@ func start(service, user string, handler ConversationHandler, confDir string) (*
err = t.handlePamStatus(C.pam_start_confdir(s, u, t.conv, c, &t.handle))
}
if err != nil {
- return nil, errors.Join(Error(t.status), err)
+ return nil, errors.Join(Error(t.lastStatus.Load()), err)
}
return t, nil
}
func (t *Transaction) Error() string {
- return Error(t.status).Error()
+ return Error(t.lastStatus.Load()).Error()
}
// Item is a an PAM information type.
@@ -363,8 +364,10 @@ func (t *Transaction) GetEnvList() (map[string]string, error) {
env := make(map[string]string)
p := C.pam_getenvlist(t.handle)
if p == nil {
- return nil, t.handlePamStatus(C.int(ErrBuf))
+ t.lastStatus.Store(int32(ErrBuf))
+ return nil, t
}
+ t.lastStatus.Store(success)
for q := p; *q != nil; q = next(q) {
chunks := strings.SplitN(C.GoString(*q), "=", 2)
if len(chunks) == 2 {