From 911a346a003fd5ef80688caf03a0296b95efee69 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marco=20Trevisan=20=28Trevi=C3=B1o=29?= Date: Fri, 29 Sep 2023 23:01:50 +0200 Subject: transaction: Use Atomic to store/load the status Transactions save the status of each operation in a status field, however such field could be written concurrently by various operations, so we need to be sure that: - We always return the status for the current operation - We store the status in a atomic way so that other actions won't create write races In general, in a multi-thread operation one should not rely on Transaction.Error() to get info about the last operation. --- transaction.go | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/transaction.go b/transaction.go index 642059f..7b19f5c 100644 --- a/transaction.go +++ b/transaction.go @@ -27,6 +27,7 @@ import ( "runtime" "runtime/cgo" "strings" + "sync/atomic" "unsafe" ) @@ -129,22 +130,22 @@ func cbPAMConv(s C.int, msg *C.char, c C.uintptr_t) (*C.char, C.int) { // //nolint:errname type Transaction struct { - handle *C.pam_handle_t - conv *C.struct_pam_conv - status C.int - c cgo.Handle + handle *C.pam_handle_t + conv *C.struct_pam_conv + lastStatus atomic.Int32 + c cgo.Handle } // transactionFinalizer cleans up the PAM handle and deletes the callback // function. func transactionFinalizer(t *Transaction) { - C.pam_end(t.handle, t.status) + C.pam_end(t.handle, C.int(t.lastStatus.Load())) t.c.Delete() } // Allows to call pam functions managing return status func (t *Transaction) handlePamStatus(cStatus C.int) error { - t.status = cStatus + t.lastStatus.Store(int32(cStatus)) if cStatus != success { return t } @@ -212,13 +213,13 @@ func start(service, user string, handler ConversationHandler, confDir string) (* err = t.handlePamStatus(C.pam_start_confdir(s, u, t.conv, c, &t.handle)) } if err != nil { - return nil, errors.Join(Error(t.status), err) + return nil, errors.Join(Error(t.lastStatus.Load()), err) } return t, nil } func (t *Transaction) Error() string { - return Error(t.status).Error() + return Error(t.lastStatus.Load()).Error() } // Item is a an PAM information type. @@ -363,8 +364,10 @@ func (t *Transaction) GetEnvList() (map[string]string, error) { env := make(map[string]string) p := C.pam_getenvlist(t.handle) if p == nil { - return nil, t.handlePamStatus(C.int(ErrBuf)) + t.lastStatus.Store(int32(ErrBuf)) + return nil, t } + t.lastStatus.Store(success) for q := p; *q != nil; q = next(q) { chunks := strings.SplitN(C.GoString(*q), "=", 2) if len(chunks) == 2 { -- cgit v1.2.3