diff options
author | Marco Trevisan (Treviño) <[email protected]> | 2023-09-29 23:01:50 +0200 |
---|---|---|
committer | Marco Trevisan (Treviño) <[email protected]> | 2023-11-30 01:16:39 +0100 |
commit | 911a346a003fd5ef80688caf03a0296b95efee69 (patch) | |
tree | af96ea2bbb21dccbbe4807196b27eb201779e50c | |
parent | 3e4f7f5e4be10027f645e21dd2aae37cd4c580a9 (diff) |
transaction: Use Atomic to store/load the status
Transactions save the status of each operation in a status field, however
such field could be written concurrently by various operations, so we
need to be sure that:
- We always return the status for the current operation
- We store the status in a atomic way so that other actions won't
create write races
In general, in a multi-thread operation one should not rely on
Transaction.Error() to get info about the last operation.
-rw-r--r-- | transaction.go | 21 |
1 files changed, 12 insertions, 9 deletions
diff --git a/transaction.go b/transaction.go index 642059f..7b19f5c 100644 --- a/transaction.go +++ b/transaction.go @@ -27,6 +27,7 @@ import ( "runtime" "runtime/cgo" "strings" + "sync/atomic" "unsafe" ) @@ -129,22 +130,22 @@ func cbPAMConv(s C.int, msg *C.char, c C.uintptr_t) (*C.char, C.int) { // //nolint:errname type Transaction struct { - handle *C.pam_handle_t - conv *C.struct_pam_conv - status C.int - c cgo.Handle + handle *C.pam_handle_t + conv *C.struct_pam_conv + lastStatus atomic.Int32 + c cgo.Handle } // transactionFinalizer cleans up the PAM handle and deletes the callback // function. func transactionFinalizer(t *Transaction) { - C.pam_end(t.handle, t.status) + C.pam_end(t.handle, C.int(t.lastStatus.Load())) t.c.Delete() } // Allows to call pam functions managing return status func (t *Transaction) handlePamStatus(cStatus C.int) error { - t.status = cStatus + t.lastStatus.Store(int32(cStatus)) if cStatus != success { return t } @@ -212,13 +213,13 @@ func start(service, user string, handler ConversationHandler, confDir string) (* err = t.handlePamStatus(C.pam_start_confdir(s, u, t.conv, c, &t.handle)) } if err != nil { - return nil, errors.Join(Error(t.status), err) + return nil, errors.Join(Error(t.lastStatus.Load()), err) } return t, nil } func (t *Transaction) Error() string { - return Error(t.status).Error() + return Error(t.lastStatus.Load()).Error() } // Item is a an PAM information type. @@ -363,8 +364,10 @@ func (t *Transaction) GetEnvList() (map[string]string, error) { env := make(map[string]string) p := C.pam_getenvlist(t.handle) if p == nil { - return nil, t.handlePamStatus(C.int(ErrBuf)) + t.lastStatus.Store(int32(ErrBuf)) + return nil, t } + t.lastStatus.Store(success) for q := p; *q != nil; q = next(q) { chunks := strings.SplitN(C.GoString(*q), "=", 2) if len(chunks) == 2 { |