From c815da61f4576b62b293c98051670acff8892133 Mon Sep 17 00:00:00 2001 From: Garry Sharp <> Date: Tue, 26 Aug 2025 20:52:28 -0500 Subject: [PATCH] =?UTF-8?q?=E2=9A=A1=EF=B8=8F=20Update=20TX=20and=20api=20?= =?UTF-8?q?code?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/types/fees.go | 37 ++--- internal/verifierapi/fees.go | 111 ++++++++++++--- internal/verifierapi/verifierapi.go | 10 +- plugin/fees/load.go | 89 +++--------- plugin/fees/post_tx.go | 42 ++---- plugin/fees/transaction.go | 127 ++++++++--------- storage/db.go | 15 +- storage/postgres/fees.go | 203 ++++++++-------------------- 8 files changed, 274 insertions(+), 360 deletions(-) diff --git a/internal/types/fees.go b/internal/types/fees.go index 78c1310..6b96541 100644 --- a/internal/types/fees.go +++ b/internal/types/fees.go @@ -8,32 +8,23 @@ import ( // DB FEE Types -type FeeRunState string +type FeeBatchState string const ( - FeeRunStateDraft FeeRunState = "draft" - FeeRunStateSent FeeRunState = "sent" - FeeRunStateSuccess FeeRunState = "completed" - FeeRunStateFailed FeeRunState = "failed" + FeeBatchStateDraft FeeBatchState = "draft" + FeeBatchStateSent FeeBatchState = "sent" + FeeBatchStateSuccess FeeBatchState = "completed" + FeeBatchStateFailed FeeBatchState = "failed" ) // individual fee record in the db -type Fee struct { - ID uuid.UUID `db:"id"` - FeeRunID uuid.UUID `db:"fee_run_id"` - Amount int `db:"amount"` - CreatedAt time.Time `db:"created_at"` -} - -// fee table or fee_run_with_totals -type FeeRun struct { - ID uuid.UUID `db:"id"` - Status FeeRunState `db:"status"` - CreatedAt time.Time `db:"created_at"` - UpdatedAt time.Time `db:"updated_at"` - TxHash *string `db:"tx_hash"` - PolicyID uuid.UUID `db:"policy_id"` - TotalAmount int `db:"total_amount"` - FeeCount int `db:"fee_count"` - Fees []Fee `db:"fees"` +type FeeBatch struct { + ID uuid.UUID `db:"id"` + BatchID uuid.UUID `db:"batch_id"` + PublicKey string `db:"public_key"` + Status FeeBatchState `db:"status"` + Amount uint64 `db:"amount"` + CreatedAt time.Time `db:"created_at"` + UpdatedAt time.Time `db:"updated_at"` + TxHash *string `db:"tx_hash"` } diff --git a/internal/verifierapi/fees.go b/internal/verifierapi/fees.go index e1dbfdc..d6aacc4 100644 --- a/internal/verifierapi/fees.go +++ b/internal/verifierapi/fees.go @@ -4,9 +4,9 @@ import ( "encoding/json" "fmt" "net/http" - "time" "github.com/google/uuid" + "github.com/vultisig/plugin/internal/types" ) type FeeDto struct { @@ -27,16 +27,47 @@ type FeeHistoryDto struct { FeesPendingCollection int `json:"fees_pending_collection" validate:"required"` // Total fees pending collection in the smallest unit, e.g., "1000000" for 0.01 VULTI } -func (v *VerifierApi) GetPublicKeysFees(ecdsaPublicKey string) (*FeeHistoryDto, error) { - response, err := v.getAuth(fmt.Sprintf("/fees/publickey/%s", ecdsaPublicKey)) +type FeeBalanceDto struct { + Balance int64 `json:"balance" validate:"required"` + PublicKey string `json:"public_key" validate:"required"` +} + +type FeeBatchCreateResponseDto struct { + PublicKey string `json:"public_key" validate:"required"` + Amount uint64 `json:"amount" validate:"required"` + BatchID uuid.UUID `json:"batch_id" validate:"required"` +} + +type FeeBatchUpdateRequestResponseDto struct { + PublicKey string `json:"public_key" validate:"required"` + BatchID uuid.UUID `json:"batch_id" validate:"required"` + TxHash string `json:"tx_hash" validate:"required"` + Status types.FeeBatchState `json:"status" validate:"required"` +} + +func (v *VerifierApi) CreateFeeBatch(publicKey string) (*FeeBatchCreateResponseDto, error) { + response, err := v.postAuth("/fees/batch", map[string]interface{}{ + "public_key": publicKey, + }) + if err != nil { + return nil, fmt.Errorf("failed to create fee batch: %w", err) + } + defer response.Body.Close() + + var feeBatchResponse APIResponse[FeeBatchCreateResponseDto] + if err := json.NewDecoder(response.Body).Decode(&feeBatchResponse); err != nil { + return nil, fmt.Errorf("failed to decode fee batch response: %w", err) + } + + return &feeBatchResponse.Data, nil +} + +func (v *VerifierApi) GetFeeHistory(ecdsaPublicKey string) (*FeeHistoryDto, error) { + response, err := v.getAuth(fmt.Sprintf("/fees/history/%s", ecdsaPublicKey)) if err != nil { return nil, fmt.Errorf("failed to get public key fees: %w", err) } - defer func() { - if err := response.Body.Close(); err != nil { - v.logger.WithError(err).Error("Failed to close response body") - } - }() + defer response.Body.Close() if response.StatusCode == http.StatusNotFound { return nil, fmt.Errorf("public key not found") } @@ -57,28 +88,64 @@ func (v *VerifierApi) GetPublicKeysFees(ecdsaPublicKey string) (*FeeHistoryDto, return &feeHistory.Data, nil } -func (v *VerifierApi) MarkFeeAsCollected(txHash string, collectedAt time.Time, feeIds ...uuid.UUID) error { +func (v *VerifierApi) GetFeeBalance(ecdsaPublicKey string) (*FeeBalanceDto, error) { + response, err := v.getAuth(fmt.Sprintf("/fees/balance/%s", ecdsaPublicKey)) + if err != nil { + return nil, fmt.Errorf("failed to get public key fees: %w", err) + } + defer response.Body.Close() + + if response.StatusCode == http.StatusNotFound { + return nil, fmt.Errorf("public key not found") + } + + if response.StatusCode != http.StatusOK { + return nil, fmt.Errorf("failed to get public key fees, status code: %d", response.StatusCode) + } - var body = struct { - IDs []uuid.UUID `json:"ids"` - TxHash string `json:"tx_hash"` - CollectedAt time.Time `json:"collected_at"` - }{ - IDs: feeIds, - TxHash: txHash, - CollectedAt: collectedAt, + var feeBalance APIResponse[FeeBalanceDto] + if err := json.NewDecoder(response.Body).Decode(&feeBalance); err != nil { + return nil, fmt.Errorf("failed to decode public key fees response: %w", err) } - url := "/fees/collected" - response, err := v.postAuth(url, body) + return &feeBalance.Data, nil +} + +func (v *VerifierApi) CreateFeeCredit(id uuid.UUID, amount int64, publicKey string) error { + response, err := v.postAuth("/fees/credit", map[string]interface{}{ + "id": id, + "amount": amount, + "public_key": publicKey, + }) if err != nil { - return fmt.Errorf("failed to mark fee as collected: %w", err) + return fmt.Errorf("failed to create fee credit: %w", err) } defer response.Body.Close() - if response.StatusCode != http.StatusOK { - return fmt.Errorf("failed to mark fee as collected, status code: %d", response.StatusCode) + return nil +} + +func (v *VerifierApi) RevertFeeCredit(txHash string, batchId uuid.UUID) error { + response, err := v.postAuth(fmt.Sprintf("/fees/revert/%s", batchId), struct{}{}) + if err != nil { + return fmt.Errorf("failed to revert fee credit: %w", err) } + defer response.Body.Close() return nil } + +func (v *VerifierApi) UpdateFeeBatchTxHash(publickey string, batchId uuid.UUID, hash string) (*FeeBatchCreateResponseDto, error) { + response, err := v.putAuth("/fees/batch", FeeBatchUpdateRequestResponseDto{ + PublicKey: publickey, + BatchID: batchId, + TxHash: hash, + Status: types.FeeBatchStateSent, + }) + if err != nil { + return nil, fmt.Errorf("failed to get fee batch: %w", err) + } + defer response.Body.Close() + + return nil, nil +} diff --git a/internal/verifierapi/verifierapi.go b/internal/verifierapi/verifierapi.go index c0d540d..4445a06 100644 --- a/internal/verifierapi/verifierapi.go +++ b/internal/verifierapi/verifierapi.go @@ -53,12 +53,20 @@ func (v *VerifierApi) getAuth(endpoint string) (*http.Response, error) { } func (v *VerifierApi) postAuth(endpoint string, body any) (*http.Response, error) { + return v.bodyRequest(endpoint, body, http.MethodPost) +} + +func (v *VerifierApi) putAuth(endpoint string, body any) (*http.Response, error) { + return v.bodyRequest(endpoint, body, http.MethodPut) +} + +func (v *VerifierApi) bodyRequest(endpoint string, body any, httpMethod string) (*http.Response, error) { jsonBody, err := json.Marshal(body) if err != nil { return nil, err } - request, err := http.NewRequest(http.MethodPost, v.url+endpoint, bytes.NewBuffer(jsonBody)) + request, err := http.NewRequest(httpMethod, v.url+endpoint, bytes.NewBuffer(jsonBody)) if err != nil { return nil, err } diff --git a/plugin/fees/load.go b/plugin/fees/load.go index beb622b..be11601 100644 --- a/plugin/fees/load.go +++ b/plugin/fees/load.go @@ -7,7 +7,6 @@ import ( "github.com/google/uuid" "github.com/hibiken/asynq" - "github.com/sirupsen/logrus" "github.com/vultisig/plugin/internal/types" vtypes "github.com/vultisig/verifier/types" "golang.org/x/sync/errgroup" @@ -46,7 +45,11 @@ func (fp *FeePlugin) LoadFees(ctx context.Context, task *asynq.Task) error { return fmt.Errorf("failed to acquire semaphore: %w", err) } defer sem.Release(1) - return fp.executeFeeLoading(ctx, feePolicy) + err := fp.executeFeeLoading(ctx, feePolicy) + if err != nil { + fp.logger.WithError(err).WithField("public_key", feePolicy.PublicKey).Error("Failed to execute fee loading") + } + return err }) } @@ -60,80 +63,28 @@ func (fp *FeePlugin) LoadFees(ctx context.Context, task *asynq.Task) error { func (fp *FeePlugin) executeFeeLoading(ctx context.Context, feePolicy vtypes.PluginPolicy) error { // Get list of fees from the verifier connected to the fee policy - feesResponse, err := fp.verifierApi.GetPublicKeysFees(feePolicy.PublicKey) + batch, err := fp.verifierApi.CreateFeeBatch(feePolicy.PublicKey) if err != nil { return fmt.Errorf("failed to get plugin policy fees: %w", err) } - // Early return if no fees to collect - if feesResponse.FeesPendingCollection <= 0 { - fp.logger.WithField("publicKey", feePolicy.PublicKey).Info("No fees pending collection") - return nil + if err != nil { + return fmt.Errorf("failed to create fee batch: %w", err) } - // If fees are greater than 0, we need to collect them - fp.logger.WithFields(logrus.Fields{ - "publicKey": feePolicy.PublicKey, - }).Info("Fees pending collection: ", feesResponse.FeesPendingCollection) - - checkAmount := 0 - for _, fee := range feesResponse.Fees { - if !fee.Collected { - checkAmount += fee.Amount - } - } - if checkAmount != feesResponse.FeesPendingCollection { - return fmt.Errorf("fees pending collection amount does not match the sum of the fees") + if batch.Amount == 0 || batch.BatchID == uuid.Nil { + fp.logger.WithField("public_key", feePolicy.PublicKey).Info("No fees to load") + return nil } - for _, fee := range feesResponse.Fees { - if !fee.Collected { - - // Check if the fee has already been loaded and added to a fee run, if so, skip it - existingFee, err := fp.db.GetFees(ctx, fee.ID) - if err != nil { - return fmt.Errorf("failed to get fee: %w", err) - } - if len(existingFee) > 0 { - fp.logger.WithFields(logrus.Fields{ - "publicKey": feePolicy.PublicKey, - "feeId": fee.ID, - "runId": existingFee[0].FeeRunID, - }).Info("Fee already added to a fee run") - continue - } - - // If the fee hasn't been loaded, look for a draft run and add it to it - run, err := fp.db.GetPendingFeeRun(ctx, feePolicy.ID) - if err != nil { - return fmt.Errorf("failed to get pending fee run: %w", err) - } + _, err = fp.db.CreateFeeBatch(ctx, nil, types.FeeBatch{ + ID: uuid.New(), + BatchID: batch.BatchID, + PublicKey: feePolicy.PublicKey, + Status: types.FeeBatchStateDraft, + TxHash: nil, + Amount: uint64(batch.Amount), + }) - // If no draft run is found, create a new one and add the fee to it - if run == nil { - run, err = fp.db.CreateFeeRun(ctx, feePolicy.ID, types.FeeRunStateDraft, fee) - if err != nil { - return fmt.Errorf("failed to create fee run: %w", err) - } - fp.logger.WithFields(logrus.Fields{ - "publicKey": feePolicy.PublicKey, - "feeIds": []uuid.UUID{fee.ID}, - "runId": run.ID, - }).Info("Fee run created") - - // If a draft run is found, add the fee to it - } else { - if err := fp.db.CreateFee(ctx, run.ID, fee); err != nil { - return fmt.Errorf("failed to create fee: %w", err) - } - fp.logger.WithFields(logrus.Fields{ - "publicKey": feePolicy.PublicKey, - "feeIds": []uuid.UUID{fee.ID}, - "runId": run.ID, - }).Info("Fee added to fee run") - } - } - } - - return nil + return err } diff --git a/plugin/fees/post_tx.go b/plugin/fees/post_tx.go index 718e9fa..2af2262 100644 --- a/plugin/fees/post_tx.go +++ b/plugin/fees/post_tx.go @@ -7,7 +7,6 @@ import ( "github.com/ethereum/go-ethereum" ecommon "github.com/ethereum/go-ethereum/common" - "github.com/google/uuid" "github.com/hibiken/asynq" "github.com/sirupsen/logrus" "github.com/vultisig/plugin/internal/types" @@ -19,7 +18,7 @@ import ( func (fp *FeePlugin) HandlePostTx(ctx context.Context, task *asynq.Task) error { // Get a list of all fee runs that are in a sent state - runs, err := fp.db.GetAllFeeRuns(ctx, types.FeeRunStateSent) + batches, err := fp.db.GetFeeBatchByStatus(ctx, types.FeeBatchStateSent) if err != nil { fp.logger.WithError(err).Error("failed to get fee runs") return fmt.Errorf("failed to get fee runs: %w", err) @@ -34,19 +33,16 @@ func (fp *FeePlugin) HandlePostTx(ctx context.Context, task *asynq.Task) error { sem := semaphore.NewWeighted(int64(fp.config.Jobs.Post.MaxConcurrentJobs)) var wg sync.WaitGroup var eg errgroup.Group - for _, run := range runs { + for _, batch := range batches { wg.Add(1) - run = run + feeBatch := batch eg.Go(func() error { defer wg.Done() if err := sem.Acquire(ctx, 1); err != nil { return fmt.Errorf("failed to acquire semaphore: %w", err) } defer sem.Release(1) - if run.TxHash == nil || run.Status == types.FeeRunStateDraft { - return nil - } - return fp.updateStatus(ctx, run, currentBlock) + return fp.updateStatus(ctx, feeBatch, currentBlock) }) } wg.Wait() @@ -57,43 +53,35 @@ func (fp *FeePlugin) HandlePostTx(ctx context.Context, task *asynq.Task) error { return nil } -func (fp *FeePlugin) updateStatus(ctx context.Context, run types.FeeRun, currentBlock uint64) error { - if run.TxHash == nil || run.Status == types.FeeRunStateDraft { +func (fp *FeePlugin) updateStatus(ctx context.Context, batch types.FeeBatch, currentBlock uint64) error { + if batch.TxHash == nil || batch.Status == types.FeeBatchStateDraft { return nil } - fp.logger.WithFields(logrus.Fields{"run_id": run.ID}).Info("Beginning status check/update") - hash := ecommon.HexToHash(*run.TxHash) + fp.logger.WithFields(logrus.Fields{"batch_id": batch.BatchID}).Info("Beginning status check/update") + hash := ecommon.HexToHash(*batch.TxHash) receipt, err := fp.ethClient.TransactionReceipt(ctx, hash) if err == ethereum.NotFound { // TODO rebroadcast logic - fp.logger.WithFields(logrus.Fields{"run_id": run.ID}).Info("tx not found on chain, rebroadcasting") + fp.logger.WithFields(logrus.Fields{"batch_id": batch.BatchID}).Info("tx not found on chain, rebroadcasting") return nil } if receipt.Status == 1 { if currentBlock > receipt.BlockNumber.Uint64()+fp.config.Jobs.Post.SuccessConfirmations { - fp.logger.WithFields(logrus.Fields{"run_id": run.ID}).Info("tx successful, setting to success") - - ids := []uuid.UUID{} - for _, fee := range run.Fees { - ids = append(ids, fee.ID) - } - - if err = fp.verifierApi.MarkFeeAsCollected(*run.TxHash, run.CreatedAt, ids...); err != nil { - return fmt.Errorf("failed to mark fee as collected on verifier: %w", err) - } + fp.logger.WithFields(logrus.Fields{"batch_id": batch.BatchID}).Info("tx successful, setting to success") // This is semi critical code as it could create a state mismatch between the verifier and the database. - if err = fp.db.SetFeeRunSuccess(ctx, run.ID); err != nil { - return fmt.Errorf("failed to set fee run success: %w", err) + if err = fp.db.SetFeeBatchStatus(ctx, nil, batch.BatchID, types.FeeBatchStateSuccess); err != nil { + return fmt.Errorf("failed to set fee batch success: %w", err) } } else { - fp.logger.WithFields(logrus.Fields{"run_id": run.ID}).Info("tx successful, but not enough confirmations, waiting for more") + fp.logger.WithFields(logrus.Fields{"batch_id": batch.BatchID}).Info("tx successful, but not enough confirmations, waiting for more") return nil } } else { // TODO failed tx logic - fp.logger.WithFields(logrus.Fields{"run_id": run.ID}).Info("tx failed, setting to failed") + fp.logger.WithFields(logrus.Fields{"batch_id": batch.BatchID}).Info("tx failed, setting to failed") + fp.verifierApi.RevertFeeCredit(*batch.TxHash, batch.BatchID) return nil } return nil diff --git a/plugin/fees/transaction.go b/plugin/fees/transaction.go index d42db0c..ad55fbc 100644 --- a/plugin/fees/transaction.go +++ b/plugin/fees/transaction.go @@ -7,13 +7,12 @@ import ( "errors" "fmt" "math/big" + "runtime/debug" "strconv" - "sync" gcommon "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/hexutil" etypes "github.com/ethereum/go-ethereum/core/types" - "github.com/google/uuid" "github.com/hibiken/asynq" "github.com/sirupsen/logrus" "github.com/vultisig/mobile-tss-lib/tss" @@ -46,47 +45,55 @@ func (fp *FeePlugin) HandleTransactions(ctx context.Context, task *asynq.Task) e }() fp.logger.Info("Getting all fee runs") - runs, err := fp.db.GetAllFeeRuns(ctx) + feeBatches, err := fp.db.GetFeeBatchByStatus(ctx, types.FeeBatchStateDraft) if err != nil { fp.logger.WithError(err).Error("Failed to get fee runs") return fmt.Errorf("failed to get fee runs: %w", err) } sem := semaphore.NewWeighted(int64(fp.config.Jobs.Transact.MaxConcurrentJobs)) - var wg sync.WaitGroup var eg errgroup.Group - for _, run := range runs { - run = run - eg.Go(func() error { - //TODO also check failed runs - if run.Status != types.FeeRunStateDraft { - return nil - } - if run.TxHash != nil { - return nil - } + for _, batch := range feeBatches { + feeBatch := batch + eg.Go(func() error { + // Add panic recovery for this goroutine + defer func() { + if r := recover(); r != nil { + fp.logger.WithFields(logrus.Fields{ + "public_key": feeBatch.PublicKey, + "panic": r, + }).Error("Recovered from panic in fee transaction processing") + debug.PrintStack() + } + }() + + fp.logger.WithFields(logrus.Fields{"public_key": feeBatch.PublicKey}).Info("Processing fee policy") - if run.FeeCount == 0 || run.TotalAmount == 0 { - return nil + if err := sem.Acquire(ctx, 1); err != nil { + return fmt.Errorf("failed to acquire semaphore: %w", err) } + defer sem.Release(1) - fp.logger.WithFields(logrus.Fields{"runId": run.ID}).Info("Processing fee run") - feePolicy, err := fp.db.GetPluginPolicy(ctx, run.PolicyID) + policies, err := fp.db.GetPluginPolicies(ctx, feeBatch.PublicKey, vtypes.PluginVultisigFees_feee, true) if err != nil { - return fmt.Errorf("failed to get fee policy: %w", err) + fp.logger.WithError(err).WithFields(logrus.Fields{ + "public_key": feeBatch.PublicKey, + }).Error("Failed to get plugin policy") + return err } - wg.Add(1) - fp.logger.WithFields(logrus.Fields{"runId": run.ID, "policyId": run.PolicyID}).Info("Retrieved fee policy") - - defer wg.Done() - if err := sem.Acquire(ctx, 1); err != nil { - return fmt.Errorf("failed to acquire semaphore: %w", err) + if len(policies) != 1 { + fp.logger.WithFields(logrus.Fields{ + "public_key": feeBatch.PublicKey, + }).Error(fmt.Sprintf("Expected 1 plugin policy, got %d", len(policies))) + return fmt.Errorf("expected 1 plugin policy, got %d", len(policies)) } - defer sem.Release(1) - if err := fp.executeFeeTransaction(ctx, run, *feePolicy); err != nil { + + policy := policies[0] + + if err := fp.executeFeeTransaction(ctx, feeBatch, policy); err != nil { fp.logger.WithError(err).WithFields(logrus.Fields{ - "runId": run.ID, + "public_key": feeBatch.PublicKey, }).Error("Failed to execute fee transaction") return err } @@ -94,7 +101,6 @@ func (fp *FeePlugin) HandleTransactions(ctx context.Context, task *asynq.Task) e }) } - wg.Wait() if err := eg.Wait(); err != nil { return fmt.Errorf("failed to execute fee transaction: %w", err) } @@ -102,15 +108,14 @@ func (fp *FeePlugin) HandleTransactions(ctx context.Context, task *asynq.Task) e return nil } -func (fp *FeePlugin) executeFeeTransaction(ctx context.Context, run types.FeeRun, feePolicy vtypes.PluginPolicy) error { +func (fp *FeePlugin) executeFeeTransaction(ctx context.Context, feeBatch types.FeeBatch, feePolicy vtypes.PluginPolicy) error { fp.logger.WithFields(logrus.Fields{ - "runId": run.ID, - "policyId": feePolicy.ID, - }).Info("Checking if fee run policy id matches fee policy id") - if run.PolicyID != feePolicy.ID { - return fmt.Errorf("fee run policy id does not match fee policy id") - } + "amount": feeBatch.Amount, + "publicKey": feePolicy.PublicKey, + "policyId": feePolicy.ID, + "batchId": feeBatch.BatchID, + }).Info("Executing fee transaction") // Get a vault and sign the transactions fp.logger.WithFields(logrus.Fields{ @@ -121,11 +126,7 @@ func (fp *FeePlugin) executeFeeTransaction(ctx context.Context, run types.FeeRun return fmt.Errorf("failed to get vault: %w", err) } - // Propose the transactions - fp.logger.WithFields(logrus.Fields{ - "publicKey": feePolicy.PublicKey, - }).Info("Proposing transactions") - keySignRequests, err := fp.proposeTransactions(ctx, feePolicy, run) + keySignRequests, err := fp.proposeTransactions(ctx, feePolicy, feeBatch.Amount) if err != nil { return fmt.Errorf("failed to propose transactions: %w", err) } @@ -134,7 +135,7 @@ func (fp *FeePlugin) executeFeeTransaction(ctx context.Context, run types.FeeRun }).Info("Key sign requests proposed") for _, keySignRequest := range keySignRequests { req := keySignRequest - if err := fp.initSign(ctx, req, feePolicy, run.ID); err != nil { + if err := fp.initSign(ctx, req, feePolicy, feeBatch); err != nil { return fmt.Errorf("failed to init sign: %w", err) } } @@ -146,9 +147,11 @@ func (fp *FeePlugin) initSign( ctx context.Context, req vtypes.PluginKeysignRequest, pluginPolicy vtypes.PluginPolicy, - runId uuid.UUID, + feeBatch types.FeeBatch, ) error { + fmt.Printf("Init sign: %+v\n", req) + sigs, err := fp.signer.Sign(ctx, req) if err != nil { fp.logger.WithError(err).Error("Keysign failed") @@ -186,12 +189,19 @@ func (fp *FeePlugin) initSign( return fmt.Errorf("failed to decode tx: %w", err) } + if err := fp.db.SetFeeBatchSent(ctx, txHash.Hash().Hex(), feeBatch.BatchID); err != nil { + return fmt.Errorf("failed to set fee batch sent: %w", err) + } + + fp.verifierApi.UpdateFeeBatchTxHash(pluginPolicy.PublicKey, feeBatch.BatchID, txHash.Hash().Hex()) + fp.logger.WithFields(logrus.Fields{ "tx_hash": txHash.Hash().Hex(), "tx_to": erc20tx.to.Hex(), "tx_amount": erc20tx.amount.String(), "tx_token": erc20tx.token.Hex(), "public_key": pluginPolicy.PublicKey, + "batch_id": feeBatch.BatchID, }).Info("fee collection transaction") tx, err := fp.eth.Send(ctx, txBytes, r, s, v) @@ -200,22 +210,19 @@ func (fp *FeePlugin) initSign( return fmt.Errorf("failed to send transaction: %w", err) } - // This is exceptionally important, as if it errors, the transaction will internally be recorded as draft, even after it's been broadcasted - if err := fp.db.SetFeeRunSent(ctx, runId, tx.Hash().Hex()); err != nil { //TODO pass the real tx id - return fmt.Errorf("failed to set fee run sent: %w", err) - } - - // Log successful transaction broadcast - fp.logger.WithField("hash", tx.Hash().Hex()).Info("fee collection transaction successfully broadcasted") + fp.logger.WithFields(logrus.Fields{ + "tx_hash": tx.Hash().Hex(), + "tx_to": erc20tx.to.Hex(), + "tx_amount": erc20tx.amount.String(), + "tx_token": erc20tx.token.Hex(), + "public_key": pluginPolicy.PublicKey, + "batch_id": feeBatch.BatchID, + }).Info("fee collection transaction successfully broadcasted") return nil } -func (fp *FeePlugin) proposeTransactions(ctx context.Context, policy vtypes.PluginPolicy, run types.FeeRun) ([]vtypes.PluginKeysignRequest, error) { - - if policy.ID != run.PolicyID { - return nil, fmt.Errorf("policy id does not match run policy id") - } +func (fp *FeePlugin) proposeTransactions(ctx context.Context, policy vtypes.PluginPolicy, amount uint64) ([]vtypes.PluginKeysignRequest, error) { vault, err := common.GetVaultFromPolicy(fp.vaultStorage, policy, fp.encryptionSecret) if err != nil { @@ -286,8 +293,6 @@ func (fp *FeePlugin) proposeTransactions(ctx context.Context, policy vtypes.Plug return nil, fmt.Errorf("token address does not match usdc address") } - amount := run.TotalAmount - tx, err := fp.eth.MakeAnyTransfer(ctx, gcommon.HexToAddress(ethAddress), gcommon.HexToAddress(recipient), @@ -318,12 +323,10 @@ func (fp *FeePlugin) proposeTransactions(ctx context.Context, policy vtypes.Plug HashFunction: vtypes.HashFunction_SHA256, }, }, - SessionID: "", - HexEncryptionKey: "", - PolicyID: policy.ID, - PluginID: policy.PluginID.String(), + PolicyID: policy.ID, + PluginID: policy.PluginID.String(), }, - Transaction: txHex, + Transaction: base64.StdEncoding.EncodeToString(tx), } txs = append(txs, signRequest) diff --git a/storage/db.go b/storage/db.go index 8a98027..17a44da 100644 --- a/storage/db.go +++ b/storage/db.go @@ -9,7 +9,6 @@ import ( vtypes "github.com/vultisig/verifier/types" "github.com/vultisig/plugin/internal/types" - "github.com/vultisig/plugin/internal/verifierapi" ) type DatabaseStorage interface { @@ -22,14 +21,12 @@ type DatabaseStorage interface { InsertPluginPolicyTx(ctx context.Context, dbTx pgx.Tx, policy vtypes.PluginPolicy) (*vtypes.PluginPolicy, error) UpdatePluginPolicyTx(ctx context.Context, dbTx pgx.Tx, policy vtypes.PluginPolicy) (*vtypes.PluginPolicy, error) - CreateFeeRun(ctx context.Context, policyId uuid.UUID, state types.FeeRunState, fees ...verifierapi.FeeDto) (*types.FeeRun, error) - SetFeeRunSent(ctx context.Context, runId uuid.UUID, txHash string) error - SetFeeRunSuccess(ctx context.Context, runId uuid.UUID) error - GetAllFeeRuns(ctx context.Context, statuses ...types.FeeRunState) ([]types.FeeRun, error) // If no statuses are provided, all fee runs are returned. - GetFees(ctx context.Context, feeIds ...uuid.UUID) ([]types.Fee, error) - GetPendingFeeRun(ctx context.Context, policyId uuid.UUID) (*types.FeeRun, error) - CreateFee(ctx context.Context, runId uuid.UUID, fee verifierapi.FeeDto) error - GetFeeRuns(ctx context.Context, state types.FeeRunState) ([]types.FeeRun, error) + CreateFeeBatch(ctx context.Context, tx *pgx.Tx, batches ...types.FeeBatch) ([]types.FeeBatch, error) + SetFeeBatchTxHash(ctx context.Context, tx *pgx.Tx, batchId uuid.UUID, txHash string) error + SetFeeBatchStatus(ctx context.Context, tx *pgx.Tx, batchId uuid.UUID, status types.FeeBatchState) error + GetFeeBatch(ctx context.Context, batchIDs ...uuid.UUID) ([]types.FeeBatch, error) + GetFeeBatchByStatus(ctx context.Context, status types.FeeBatchState) ([]types.FeeBatch, error) + SetFeeBatchSent(ctx context.Context, txHash string, batchId uuid.UUID) error Pool() *pgxpool.Pool } diff --git a/storage/postgres/fees.go b/storage/postgres/fees.go index afec256..937416d 100644 --- a/storage/postgres/fees.go +++ b/storage/postgres/fees.go @@ -2,198 +2,107 @@ package postgres import ( "context" - "errors" - "fmt" "github.com/google/uuid" - "github.com/jackc/pgx/v5" + "github.com/vultisig/plugin/internal/types" - "github.com/vultisig/plugin/internal/verifierapi" ) -func (p *PostgresBackend) CreateFeeRun(ctx context.Context, policyId uuid.UUID, state types.FeeRunState, fees ...verifierapi.FeeDto) (*types.FeeRun, error) { - // Check policy id is valid - query := `select plugin_id from plugin_policies where id = $1` - policyrows := p.pool.QueryRow(ctx, query, policyId) - var pluginId string - err := policyrows.Scan(&pluginId) - if err != nil { - return nil, err - } - if pluginId != "vultisig-fees-feee" { - return nil, errors.New("plugin id not found or not vultisig-fees-feee") - } - - tx, err := p.pool.Begin(ctx) - if err != nil { - return nil, err - } - defer tx.Rollback(ctx) - runId := uuid.New() - _, err = tx.Exec(ctx, `insert into fee_run (id, status, policy_id) values ($1, $2, $3) returning id`, runId, state, policyId) - if err != nil { - return nil, fmt.Errorf("failed to insert fee run: %w", err) - } - - for _, fee := range fees { - _, err = tx.Exec(ctx, `insert into fee (id, fee_run_id, amount) values ($1, $2, $3)`, fee.ID, runId, fee.Amount) +func (p *PostgresBackend) CreateFeeBatch(ctx context.Context, tx *pgx.Tx, batches ...types.FeeBatch) ([]types.FeeBatch, error) { + if tx == nil { + _tx, err := p.pool.Begin(ctx) if err != nil { - return nil, fmt.Errorf("failed to insert fee: %w", err) + return nil, err + } + tx = &_tx + defer func() { + if err != nil { + (*tx).Rollback(ctx) + } + (*tx).Commit(ctx) + }() + } + + query := `insert into fee_batch (id, batch_id, public_key, status, amount, tx_hash) values ($1, $2, $3, $4, $5, $6) returning *` + feeBatches := make([]types.FeeBatch, 0, len(batches)) + for _, batch := range batches { + rows, err := (*tx).Query(ctx, query, batch.ID, batch.BatchID, batch.PublicKey, batch.Status, batch.Amount, batch.TxHash) + if err != nil { + return nil, err } - } - err = tx.Commit(ctx) - if err != nil { - return nil, fmt.Errorf("failed to commit transaction: %w", err) - } + insertedBatch, err := pgx.CollectOneRow(rows, pgx.RowToStructByName[types.FeeBatch]) + if err != nil { + rows.Close() + return nil, err + } + rows.Close() - var run types.FeeRun - err = p.pool.QueryRow(ctx, `select id, status, created_at, updated_at, tx_hash, policy_id, total_amount, fee_count from fee_run_with_totals where id = $1`, runId).Scan(&run.ID, &run.Status, &run.CreatedAt, &run.UpdatedAt, &run.TxHash, &run.PolicyID, &run.TotalAmount, &run.FeeCount) - if err != nil { - return nil, fmt.Errorf("failed to get fee run (post commit): %s", err) + feeBatches = append(feeBatches, insertedBatch) } - return &run, nil + return feeBatches, nil } -func (p *PostgresBackend) SetFeeRunSent(ctx context.Context, runId uuid.UUID, txHash string) error { - _, err := p.pool.Exec(ctx, `update fee_run set status = $1, tx_hash = $2 where id = $3`, types.FeeRunStateSent, txHash, runId) +func (p *PostgresBackend) SetFeeBatchTxHash(ctx context.Context, tx *pgx.Tx, batchId uuid.UUID, txHash string) error { + query := `update fee_batch set tx_hash = $1 where batch_id = $2` + _, err := (*tx).Exec(ctx, query, txHash, batchId) if err != nil { - return fmt.Errorf("failed to update fee run: %w", err) + return err } return nil } -func (p *PostgresBackend) SetFeeRunSuccess(ctx context.Context, runId uuid.UUID) error { - _, err := p.pool.Exec(ctx, `update fee_run set status = $1 where id = $2`, types.FeeRunStateSuccess, runId) +func (p *PostgresBackend) SetFeeBatchStatus(ctx context.Context, tx *pgx.Tx, batchId uuid.UUID, status types.FeeBatchState) error { + query := `update fee_batch set status = $1 where batch_id = $2` + _, err := (*tx).Exec(ctx, query, status, batchId) if err != nil { - return fmt.Errorf("failed to update fee run: %w", err) + return err } return nil } -func (p *PostgresBackend) GetAllFeeRuns(ctx context.Context, statuses ...types.FeeRunState) ([]types.FeeRun, error) { - - var rows pgx.Rows - var err error - - if len(statuses) == 0 { - query := `select id, status, created_at, updated_at, tx_hash, policy_id, total_amount, fee_count from fee_run_with_totals` - rows, err = p.pool.Query(ctx, query) - } else { - query := `select id, status, created_at, updated_at, tx_hash, policy_id, total_amount, fee_count from fee_run_with_totals where status = ANY($1)` - rows, err = p.pool.Query(ctx, query, statuses) - } - +func (p *PostgresBackend) GetFeeBatch(ctx context.Context, batchIDs ...uuid.UUID) ([]types.FeeBatch, error) { + query := `select * from fee_batch where id = $1` + rows, err := p.pool.Query(ctx, query, batchIDs) if err != nil { return nil, err } defer rows.Close() - rm := make(map[uuid.UUID]types.FeeRun) - for rows.Next() { - var run types.FeeRun - err := rows.Scan(&run.ID, &run.Status, &run.CreatedAt, &run.UpdatedAt, &run.TxHash, &run.PolicyID, &run.TotalAmount, &run.FeeCount) - if err != nil { - return nil, err - } - rm[run.ID] = run - } - runIds := make([]uuid.UUID, 0, len(rm)) - for runId := range rm { - runIds = append(runIds, runId) - } - - feesQuery := `select id, fee_run_id, amount from fee where fee_run_id = ANY($1)` - feesRows, err := p.pool.Query(ctx, feesQuery, runIds) - if err != nil { - return nil, err - } - defer feesRows.Close() - for feesRows.Next() { - var fee types.Fee - err := feesRows.Scan(&fee.ID, &fee.FeeRunID, &fee.Amount) + feeBatches := make([]types.FeeBatch, 0, len(batchIDs)) + for rows.Next() { + feebatch, err := pgx.RowToStructByName[types.FeeBatch](rows) if err != nil { return nil, err } - if run, ok := rm[fee.FeeRunID]; !ok { - return nil, fmt.Errorf("fee run not found: %s", fee.FeeRunID) - } else { - run.Fees = append(run.Fees, fee) - rm[fee.FeeRunID] = run - } - } - - runs := make([]types.FeeRun, 0, len(rm)) - for _, run := range rm { - runs = append(runs, run) + feeBatches = append(feeBatches, feebatch) } - return runs, nil + return feeBatches, nil } -func (p *PostgresBackend) GetFees(ctx context.Context, feeIds ...uuid.UUID) ([]types.Fee, error) { - query := `select id, fee_run_id, amount from fee where id = ANY($1)` - rows, err := p.pool.Query(ctx, query, feeIds) +func (p *PostgresBackend) GetFeeBatchByStatus(ctx context.Context, status types.FeeBatchState) ([]types.FeeBatch, error) { + query := `select * from fee_batch where status = $1` + rows, err := p.pool.Query(ctx, query, status) if err != nil { return nil, err } defer rows.Close() - fees := []types.Fee{} + feeBatches := []types.FeeBatch{} for rows.Next() { - var fee types.Fee - err := rows.Scan(&fee.ID, &fee.FeeRunID, &fee.Amount) + feebatch, err := pgx.RowToStructByName[types.FeeBatch](rows) if err != nil { return nil, err } - fees = append(fees, fee) + feeBatches = append(feeBatches, feebatch) } - return fees, nil + return feeBatches, nil } -func (p *PostgresBackend) GetPendingFeeRun(ctx context.Context, policyId uuid.UUID) (*types.FeeRun, error) { - query := `select id, status, created_at, updated_at, tx_hash, policy_id, total_amount, fee_count from fee_run_with_totals where status = $1 and policy_id = $2 order by created_at desc limit 1` - rows, err := p.pool.Query(ctx, query, types.FeeRunStateDraft, policyId) - if err != nil { - return nil, err - } - defer rows.Close() - if !rows.Next() { - return nil, nil - } - var run types.FeeRun - err = rows.Scan(&run.ID, &run.Status, &run.CreatedAt, &run.UpdatedAt, &run.TxHash, &run.PolicyID, &run.TotalAmount, &run.FeeCount) - if err != nil { - return nil, err - } - return &run, nil -} - -func (p *PostgresBackend) CreateFee(ctx context.Context, runId uuid.UUID, fee verifierapi.FeeDto) error { - _, err := p.pool.Exec(ctx, `insert into fee (id, fee_run_id, amount) values ($1, $2, $3)`, fee.ID, runId, fee.Amount) - if err != nil { - return fmt.Errorf("failed to insert fee: %w", err) - } - return nil -} - -func (p *PostgresBackend) GetFeeRuns(ctx context.Context, state types.FeeRunState) ([]types.FeeRun, error) { - query := `select id, status, created_at, updated_at, tx_hash, policy_id, total_amount, fee_count from fee_run_with_totals where status = $1` - rows, err := p.pool.Query(ctx, query, state) - if err != nil { - return nil, err - } - defer rows.Close() - runs := []types.FeeRun{} - - for rows.Next() { - var run types.FeeRun - err := rows.Scan(&run.ID, &run.Status, &run.CreatedAt, &run.UpdatedAt, &run.TxHash, &run.PolicyID, &run.TotalAmount, &run.FeeCount) - if err != nil { - return nil, err - } - runs = append(runs, run) - } - return runs, nil +func (p *PostgresBackend) SetFeeBatchSent(ctx context.Context, txHash string, batchId uuid.UUID) error { + query := `update fee_batch set status = $1, tx_hash = $2 where batch_id = $3` + _, err := p.pool.Exec(ctx, query, types.FeeBatchStateSent, txHash, batchId) + return err }