diff --git a/internal/verifierapi/fees.go b/internal/verifierapi/fees.go index d6aacc4..0a8e972 100644 --- a/internal/verifierapi/fees.go +++ b/internal/verifierapi/fees.go @@ -62,6 +62,21 @@ func (v *VerifierApi) CreateFeeBatch(publicKey string) (*FeeBatchCreateResponseD return &feeBatchResponse.Data, nil } +func (v *VerifierApi) GetDraftBatches(publicKey string) ([]FeeBatchCreateResponseDto, error) { + response, err := v.getAuth(fmt.Sprintf("/fees/batch/draft/%s", publicKey)) + if err != nil { + return nil, fmt.Errorf("failed to get draft batches: %w", err) + } + defer response.Body.Close() + + var feeBatches APIResponse[[]FeeBatchCreateResponseDto] + if err := json.NewDecoder(response.Body).Decode(&feeBatches); err != nil { + return nil, fmt.Errorf("failed to decode fee batches response: %w", err) + } + + return feeBatches.Data, nil +} + func (v *VerifierApi) GetFeeHistory(ecdsaPublicKey string) (*FeeHistoryDto, error) { response, err := v.getAuth(fmt.Sprintf("/fees/history/%s", ecdsaPublicKey)) if err != nil { diff --git a/plugin/fees/load.go b/plugin/fees/load.go index be11601..60b10b6 100644 --- a/plugin/fees/load.go +++ b/plugin/fees/load.go @@ -45,7 +45,15 @@ func (fp *FeePlugin) LoadFees(ctx context.Context, task *asynq.Task) error { return fmt.Errorf("failed to acquire semaphore: %w", err) } defer sem.Release(1) - err := fp.executeFeeLoading(ctx, feePolicy) + + // Here we load any existing batches that are in draft state, or that may have been missed along the way. + err := fp.loadExistingBatches(ctx, feePolicy) + if err != nil { + fp.logger.WithError(err).WithField("public_key", feePolicy.PublicKey).Error("Failed to load existing batches") + } + + // Here we create a new batch, later these jobs could run separately on different frequencies. + err = fp.executeFeeLoading(ctx, feePolicy) if err != nil { fp.logger.WithError(err).WithField("public_key", feePolicy.PublicKey).Error("Failed to execute fee loading") } @@ -60,6 +68,51 @@ func (fp *FeePlugin) LoadFees(ctx context.Context, task *asynq.Task) error { return nil } +func (fp *FeePlugin) loadExistingBatches(ctx context.Context, feePolicy vtypes.PluginPolicy) error { + batches, err := fp.verifierApi.GetDraftBatches(feePolicy.PublicKey) + if err != nil { + return fmt.Errorf("failed to get fee batches: %w", err) + } + + for _, batch := range batches { + batches, err := fp.db.GetFeeBatch(ctx, batch.BatchID) + if err != nil { + return err + } + + if len(batches) == 0 { + tx, err := fp.db.Pool().Begin(ctx) + if err != nil { + return err + } + _, err = fp.db.CreateFeeBatch(ctx, tx, types.FeeBatch{ + ID: uuid.New(), + BatchID: batch.BatchID, + PublicKey: feePolicy.PublicKey, + Status: types.FeeBatchStateDraft, + TxHash: nil, + Amount: batch.Amount, + }) + if err != nil { + tx.Rollback(ctx) + return err + } + err = tx.Commit(ctx) + if err != nil { + return err + } + fp.logger.WithField("public_key", feePolicy.PublicKey).WithField("batch_id", batch.BatchID).Info("Created draft batch") + } else { + fp.logger.WithField("public_key", feePolicy.PublicKey).WithField("batch_id", batch.BatchID).Info("Draft batch already exists") + } + } + if len(batches) == 0 { + fp.logger.WithField("public_key", feePolicy.PublicKey).Info("No draft batches found") + } + + return nil +} + func (fp *FeePlugin) executeFeeLoading(ctx context.Context, feePolicy vtypes.PluginPolicy) error { // Get list of fees from the verifier connected to the fee policy @@ -86,5 +139,10 @@ func (fp *FeePlugin) executeFeeLoading(ctx context.Context, feePolicy vtypes.Plu Amount: uint64(batch.Amount), }) - return err + if err != nil { + return fmt.Errorf("failed to create fee batch: %w", err) + } + + fp.logger.WithField("public_key", feePolicy.PublicKey).WithField("batch_id", batch.BatchID).Info("Created draft batch") + return nil } diff --git a/plugin/fees/post_tx.go b/plugin/fees/post_tx.go index 2af2262..b08f38e 100644 --- a/plugin/fees/post_tx.go +++ b/plugin/fees/post_tx.go @@ -3,7 +3,6 @@ package fees import ( "context" "fmt" - "sync" "github.com/ethereum/go-ethereum" ecommon "github.com/ethereum/go-ethereum/common" @@ -31,13 +30,10 @@ func (fp *FeePlugin) HandlePostTx(ctx context.Context, task *asynq.Task) error { } sem := semaphore.NewWeighted(int64(fp.config.Jobs.Post.MaxConcurrentJobs)) - var wg sync.WaitGroup - var eg errgroup.Group + eg, ctx := errgroup.WithContext(ctx) for _, batch := range batches { - wg.Add(1) feeBatch := batch eg.Go(func() error { - defer wg.Done() if err := sem.Acquire(ctx, 1); err != nil { return fmt.Errorf("failed to acquire semaphore: %w", err) } @@ -45,7 +41,6 @@ func (fp *FeePlugin) HandlePostTx(ctx context.Context, task *asynq.Task) error { return fp.updateStatus(ctx, feeBatch, currentBlock) }) } - wg.Wait() if err := eg.Wait(); err != nil { return fmt.Errorf("failed to execute fee run status check: %w", err) } @@ -70,10 +65,33 @@ func (fp *FeePlugin) updateStatus(ctx context.Context, batch types.FeeBatch, cur if currentBlock > receipt.BlockNumber.Uint64()+fp.config.Jobs.Post.SuccessConfirmations { fp.logger.WithFields(logrus.Fields{"batch_id": batch.BatchID}).Info("tx successful, setting to success") - // This is semi critical code as it could create a state mismatch between the verifier and the database. - if err = fp.db.SetFeeBatchStatus(ctx, nil, batch.BatchID, types.FeeBatchStateSuccess); err != nil { + tx, err := fp.db.Pool().Begin(ctx) + if err != nil { + return err + } + var rollbackErr error + defer func() { + if rollbackErr != nil { + tx.Rollback(ctx) + } + }() + + fp.verifierApi.UpdateFeeBatchTxHash(*batch.TxHash, batch.BatchID, *batch.TxHash) + + if err = fp.db.SetFeeBatchStatus(ctx, tx, batch.BatchID, types.FeeBatchStateSuccess); err != nil { + rollbackErr = err + return fmt.Errorf("failed to update verifier fee batch to success: %w", err) + } + + if err = fp.db.SetFeeBatchStatus(ctx, tx, batch.BatchID, types.FeeBatchStateSuccess); err != nil { + rollbackErr = err return fmt.Errorf("failed to set fee batch success: %w", err) } + + if err = tx.Commit(ctx); err != nil { + rollbackErr = err + return fmt.Errorf("failed to commit transaction: %w", err) + } } else { fp.logger.WithFields(logrus.Fields{"batch_id": batch.BatchID}).Info("tx successful, but not enough confirmations, waiting for more") return nil diff --git a/plugin/fees/transaction.go b/plugin/fees/transaction.go index ad55fbc..b4c5f3a 100644 --- a/plugin/fees/transaction.go +++ b/plugin/fees/transaction.go @@ -126,7 +126,7 @@ func (fp *FeePlugin) executeFeeTransaction(ctx context.Context, feeBatch types.F return fmt.Errorf("failed to get vault: %w", err) } - keySignRequests, err := fp.proposeTransactions(ctx, feePolicy, feeBatch.Amount) + keySignRequests, err := fp.proposeTransactions(ctx, feePolicy, feeBatch, feeBatch.Amount) if err != nil { return fmt.Errorf("failed to propose transactions: %w", err) } @@ -170,20 +170,24 @@ func (fp *FeePlugin) initSign( sig = s } - txBytes, txErr := hexutilDecode(req.Transaction) + decodedHexTx, decodedHexTxErr := base64.StdEncoding.DecodeString(req.Transaction) + if decodedHexTxErr != nil { + return fmt.Errorf("failed to decode transaction: %w", decodedHexTxErr) + } + r, rErr := hexutilDecode(sig.R) s, sErr := hexutilDecode(sig.S) v, vErr := hexutilDecode(sig.RecoveryID) - if txErr != nil || rErr != nil || sErr != nil || vErr != nil { - return fmt.Errorf("error decoding tx or sigs: %w", errors.Join(txErr, rErr, sErr, vErr)) + if rErr != nil || sErr != nil || vErr != nil { + return fmt.Errorf("error decoding tx or sigs: %w", errors.Join(rErr, sErr, vErr)) } - txHash, err := getHash(txBytes, r, s, v, fp.config.ChainId) + txHash, err := getHash(decodedHexTx, r, s, v, fp.config.ChainId) if err != nil { return fmt.Errorf("failed to get hash: %w", err) } - erc20tx, err := decodeTx(req.Transaction) + erc20tx, err := decodeTx(hexutil.Encode(decodedHexTx)) if err != nil { fp.logger.WithError(err).Error("failed to decode tx") return fmt.Errorf("failed to decode tx: %w", err) @@ -204,7 +208,7 @@ func (fp *FeePlugin) initSign( "batch_id": feeBatch.BatchID, }).Info("fee collection transaction") - tx, err := fp.eth.Send(ctx, txBytes, r, s, v) + tx, err := fp.eth.Send(ctx, decodedHexTx, r, s, v) if err != nil { fp.logger.WithError(err).WithField("tx_hex", req.Transaction).Error("fp.eth.Send") return fmt.Errorf("failed to send transaction: %w", err) @@ -222,7 +226,7 @@ func (fp *FeePlugin) initSign( } -func (fp *FeePlugin) proposeTransactions(ctx context.Context, policy vtypes.PluginPolicy, amount uint64) ([]vtypes.PluginKeysignRequest, error) { +func (fp *FeePlugin) proposeTransactions(ctx context.Context, policy vtypes.PluginPolicy, feeBatch types.FeeBatch, amount uint64) ([]vtypes.PluginKeysignRequest, error) { vault, err := common.GetVaultFromPolicy(fp.vaultStorage, policy, fp.encryptionSecret) if err != nil { @@ -321,6 +325,9 @@ func (fp *FeePlugin) proposeTransactions(ctx context.Context, policy vtypes.Plug Chain: vgcommon.Ethereum, Hash: base64.StdEncoding.EncodeToString(msgHash[:]), HashFunction: vtypes.HashFunction_SHA256, + CustomData: map[string]interface{}{ + "batch_id": feeBatch.BatchID.String(), + }, }, }, PolicyID: policy.ID, @@ -328,7 +335,6 @@ func (fp *FeePlugin) proposeTransactions(ctx context.Context, policy vtypes.Plug }, Transaction: base64.StdEncoding.EncodeToString(tx), } - txs = append(txs, signRequest) return txs, nil diff --git a/storage/db.go b/storage/db.go index 17a44da..1e9ba06 100644 --- a/storage/db.go +++ b/storage/db.go @@ -21,9 +21,9 @@ type DatabaseStorage interface { InsertPluginPolicyTx(ctx context.Context, dbTx pgx.Tx, policy vtypes.PluginPolicy) (*vtypes.PluginPolicy, error) UpdatePluginPolicyTx(ctx context.Context, dbTx pgx.Tx, policy vtypes.PluginPolicy) (*vtypes.PluginPolicy, error) - CreateFeeBatch(ctx context.Context, tx *pgx.Tx, batches ...types.FeeBatch) ([]types.FeeBatch, error) - SetFeeBatchTxHash(ctx context.Context, tx *pgx.Tx, batchId uuid.UUID, txHash string) error - SetFeeBatchStatus(ctx context.Context, tx *pgx.Tx, batchId uuid.UUID, status types.FeeBatchState) error + CreateFeeBatch(ctx context.Context, tx pgx.Tx, batches ...types.FeeBatch) ([]types.FeeBatch, error) + SetFeeBatchTxHash(ctx context.Context, tx pgx.Tx, batchId uuid.UUID, txHash string) error + SetFeeBatchStatus(ctx context.Context, tx pgx.Tx, batchId uuid.UUID, status types.FeeBatchState) error GetFeeBatch(ctx context.Context, batchIDs ...uuid.UUID) ([]types.FeeBatch, error) GetFeeBatchByStatus(ctx context.Context, status types.FeeBatchState) ([]types.FeeBatch, error) SetFeeBatchSent(ctx context.Context, txHash string, batchId uuid.UUID) error diff --git a/storage/postgres/fees.go b/storage/postgres/fees.go index 937416d..cc51670 100644 --- a/storage/postgres/fees.go +++ b/storage/postgres/fees.go @@ -9,25 +9,11 @@ import ( "github.com/vultisig/plugin/internal/types" ) -func (p *PostgresBackend) CreateFeeBatch(ctx context.Context, tx *pgx.Tx, batches ...types.FeeBatch) ([]types.FeeBatch, error) { - if tx == nil { - _tx, err := p.pool.Begin(ctx) - if err != nil { - return nil, err - } - tx = &_tx - defer func() { - if err != nil { - (*tx).Rollback(ctx) - } - (*tx).Commit(ctx) - }() - } - +func (p *PostgresBackend) CreateFeeBatch(ctx context.Context, tx pgx.Tx, batches ...types.FeeBatch) ([]types.FeeBatch, error) { query := `insert into fee_batch (id, batch_id, public_key, status, amount, tx_hash) values ($1, $2, $3, $4, $5, $6) returning *` feeBatches := make([]types.FeeBatch, 0, len(batches)) for _, batch := range batches { - rows, err := (*tx).Query(ctx, query, batch.ID, batch.BatchID, batch.PublicKey, batch.Status, batch.Amount, batch.TxHash) + rows, err := tx.Query(ctx, query, batch.ID, batch.BatchID, batch.PublicKey, batch.Status, batch.Amount, batch.TxHash) if err != nil { return nil, err } @@ -45,18 +31,9 @@ func (p *PostgresBackend) CreateFeeBatch(ctx context.Context, tx *pgx.Tx, batche return feeBatches, nil } -func (p *PostgresBackend) SetFeeBatchTxHash(ctx context.Context, tx *pgx.Tx, batchId uuid.UUID, txHash string) error { - query := `update fee_batch set tx_hash = $1 where batch_id = $2` - _, err := (*tx).Exec(ctx, query, txHash, batchId) - if err != nil { - return err - } - return nil -} - -func (p *PostgresBackend) SetFeeBatchStatus(ctx context.Context, tx *pgx.Tx, batchId uuid.UUID, status types.FeeBatchState) error { +func (p *PostgresBackend) SetFeeBatchStatus(ctx context.Context, tx pgx.Tx, batchId uuid.UUID, status types.FeeBatchState) error { query := `update fee_batch set status = $1 where batch_id = $2` - _, err := (*tx).Exec(ctx, query, status, batchId) + _, err := tx.Exec(ctx, query, status, batchId) if err != nil { return err } @@ -64,7 +41,8 @@ func (p *PostgresBackend) SetFeeBatchStatus(ctx context.Context, tx *pgx.Tx, bat } func (p *PostgresBackend) GetFeeBatch(ctx context.Context, batchIDs ...uuid.UUID) ([]types.FeeBatch, error) { - query := `select * from fee_batch where id = $1` + + query := `select * from fee_batch where batch_id = ANY($1)` rows, err := p.pool.Query(ctx, query, batchIDs) if err != nil { return nil, err