Skip to content

chore: use fmt.Errorf instead of errors.Wrap #146

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions apierrors.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ import (
var (
ProvisionWarehouseTimeout = "ProvisionWarehouseTimeout"

ErrDoRequest = errors.New("DoReqeustFailed")
ErrReadResponse = errors.New("ReadResponseFailed")
ErrDoRequest = errors.New("failed to do request")
ErrReadResponse = errors.New("failed to read response")
)

type APIErrorResponseBody struct {
Expand Down
70 changes: 38 additions & 32 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"database/sql/driver"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"math/rand"
Expand All @@ -18,7 +19,6 @@ import (

"github.com/avast/retry-go"
"github.com/google/uuid"
"github.com/pkg/errors"
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
)

Expand Down Expand Up @@ -229,48 +229,54 @@ func (c *APIClient) doRequest(ctx context.Context, method, path string, req inte
if req != nil {
reqBody, err = json.Marshal(req)
if err != nil {
return errors.Wrap(err, "failed to marshal request body")
return fmt.Errorf("failed to marshal request body: %w", err)
}
}

url := c.makeURL(path)
httpReq, err := http.NewRequest(method, url, bytes.NewBuffer(reqBody))
if err != nil {
return errors.Wrap(err, "failed to create http request")
return fmt.Errorf("failed to create http request: %w", err)
}
httpReq = httpReq.WithContext(ctx)

maxRetries := 2
for i := 1; i <= maxRetries; i++ {
// do not retry if context is canceled
select {
case <-ctx.Done():
return ctx.Err()
default:
}

headers, err := c.makeHeaders(ctx)
if err != nil {
return fmt.Errorf("failed to make request headers: %w", err)
}
if needSticky && len(c.NodeID) != 0 {
headers.Set(DatabendQueryStickyNode, c.NodeID)
}
if err != nil {
return errors.Wrap(err, "failed to make request headers")
}
if method == "GET" && len(c.NodeID) != 0 {
headers.Set(DatabendQueryIDNode, c.NodeID)
}
headers.Set(contentType, jsonContentType)
headers.Set(accept, jsonContentType)
httpReq.Header = headers

if len(c.host) > 0 {
httpReq.Host = c.host
}

httpResp, err := c.cli.Do(httpReq)
if err != nil {
return errors.Wrap(ErrDoRequest, err.Error())
return errors.Join(ErrDoRequest, err)
}
defer func() {
_ = httpResp.Body.Close()
}()

httpRespBody, err := io.ReadAll(httpResp.Body)
if err != nil {
return errors.Wrap(ErrReadResponse, err.Error())
return errors.Join(ErrReadResponse, err)
}

if httpResp.StatusCode == http.StatusUnauthorized {
Expand All @@ -292,7 +298,7 @@ func (c *APIClient) doRequest(ctx context.Context, method, path string, req inte
contentType := httpResp.Header.Get("Content-Type")
if strings.HasPrefix(contentType, "application/json") {
if err := json.Unmarshal(httpRespBody, &resp); err != nil {
return errors.Wrap(err, "failed to unmarshal response body")
return fmt.Errorf("failed to unmarshal response body: %w", err)
}
}
}
Expand All @@ -301,7 +307,7 @@ func (c *APIClient) doRequest(ctx context.Context, method, path string, req inte
}
return nil
}
return errors.Errorf("failed to do request after %d retries", maxRetries)
return fmt.Errorf("failed to do request after %d retries", maxRetries)
}

func (c *APIClient) trackStats(resp *QueryResponse) {
Expand Down Expand Up @@ -355,11 +361,11 @@ func (c *APIClient) makeHeaders(ctx context.Context) (http.Header, error) {
case AuthMethodAccessToken:
accessToken, err := c.accessTokenLoader.LoadAccessToken(context.TODO(), false)
if err != nil {
return nil, errors.Wrap(err, "failed to load access token")
return nil, fmt.Errorf("failed to load access token: %w", err)
}
headers.Set(Authorization, fmt.Sprintf("Bearer %s", accessToken))
default:
return nil, errors.New("no user password or access token")
return nil, fmt.Errorf("no user password or access token")
}

return headers, nil
Expand Down Expand Up @@ -426,7 +432,7 @@ func (c *APIClient) PollUntilQueryEnd(ctx context.Context, resp *QueryResponse)
return nil, err
}
if resp.Error != nil {
return nil, errors.Wrap(resp.Error, "query page has error")
return nil, fmt.Errorf("query page has error: %w", resp.Error)
}
resp.Data = append(data, resp.Data...)
}
Expand All @@ -437,7 +443,7 @@ func buildQuery(query string, params []driver.Value) (string, error) {
if len(params) > 0 && params[0] != nil {
result, err := interpolateParams(query, params)
if err != nil {
return result, errors.Wrap(err, "buildRequest: failed to interpolate params")
return result, fmt.Errorf("buildRequest: failed to interpolate params: %w", err)
}
return result, nil
}
Expand Down Expand Up @@ -508,7 +514,7 @@ func (c *APIClient) startQueryRequest(ctx context.Context, request *QueryRequest
}, Query,
)
if err != nil {
return nil, errors.Wrap(err, "failed to do query request")
return nil, fmt.Errorf("failed to do query request: %w", err)
}

if len(resp.NodeID) != 0 {
Expand Down Expand Up @@ -551,7 +557,7 @@ func (c *APIClient) PollQuery(ctx context.Context, nextURI string) (*QueryRespon
c.applySessionState(&result)
c.trackStats(&result)
if err != nil {
return nil, errors.Wrap(err, "failed to query page")
return nil, fmt.Errorf("failed to query page: %w", err)
}
return &result, nil
}
Expand Down Expand Up @@ -608,7 +614,7 @@ func (c *APIClient) InsertWithStage(ctx context.Context, sql string, stage *Stag
_ = c.CloseQuery(ctx, resp)
}()
if resp.Error != nil {
return nil, errors.Wrap(resp.Error, "query error:")
return nil, fmt.Errorf("query error: %w", resp.Error)
}
return c.PollUntilQueryEnd(ctx, resp)
}
Expand All @@ -625,20 +631,20 @@ func (c *APIClient) GetPresignedURL(ctx context.Context, stage *StageLocation) (
presignUploadSQL := fmt.Sprintf("PRESIGN UPLOAD %s", stage)
resp, err := c.QuerySync(ctx, presignUploadSQL, nil)
if err != nil {
return nil, errors.Wrap(err, "failed to query presign url")
return nil, fmt.Errorf("failed to query presign url: %w", err)
}
if len(resp.Data) < 1 || len(resp.Data[0]) < 2 {
return nil, errors.Errorf("generate presign url invalid response: %+v", resp.Data)
return nil, fmt.Errorf("generate presign url invalid response: %+v", resp.Data)
}
if resp.Data[0][0] == nil || resp.Data[0][1] == nil || resp.Data[0][2] == nil {
return nil, errors.Errorf("generate presign url invalid response: %+v", resp.Data)
return nil, fmt.Errorf("generate presign url invalid response: %+v", resp.Data)
}
method := *resp.Data[0][0]
url := *resp.Data[0][2]
headers := map[string]string{}
err = json.Unmarshal([]byte(*resp.Data[0][1]), &headers)
if err != nil {
return nil, errors.Wrap(err, "failed to unmarshal headers")
return nil, fmt.Errorf("failed to unmarshal headers: %w", err)
}
result := &PresignedResponse{
Method: method,
Expand All @@ -651,7 +657,7 @@ func (c *APIClient) GetPresignedURL(ctx context.Context, stage *StageLocation) (
func (c *APIClient) UploadToStageByPresignURL(ctx context.Context, stage *StageLocation, input *bufio.Reader, size int64) error {
presigned, err := c.GetPresignedURL(ctx, stage)
if err != nil {
return errors.Wrap(err, "failed to get presigned url")
return fmt.Errorf("failed to get presigned url: %w", err)
}

req, err := http.NewRequest("PUT", presigned.URL, input)
Expand All @@ -668,7 +674,7 @@ func (c *APIClient) UploadToStageByPresignURL(ctx context.Context, stage *StageL
}
resp, err := httpClient.Do(req)
if err != nil {
return errors.Wrap(err, "failed to upload to stage by presigned url")
return fmt.Errorf("failed to upload to stage by presigned url: %w", err)
}
defer func() {
_ = resp.Body.Close()
Expand All @@ -678,7 +684,7 @@ func (c *APIClient) UploadToStageByPresignURL(ctx context.Context, stage *StageL
return err
}
if resp.StatusCode >= 400 {
return errors.Errorf("failed to upload to stage by presigned url, status code: %d, body: %s", resp.StatusCode, string(respBody))
return fmt.Errorf("failed to upload to stage by presigned url, status code: %d, body: %s", resp.StatusCode, string(respBody))
}
return nil
}
Expand All @@ -688,28 +694,28 @@ func (c *APIClient) UploadToStageByAPI(ctx context.Context, stage *StageLocation
writer := multipart.NewWriter(body)
part, err := writer.CreateFormFile("upload", stage.Path)
if err != nil {
return errors.Wrap(err, "failed to create multipart writer form file")
return fmt.Errorf("failed to create multipart writer form file: %w", err)
}
// TODO: do async upload
_, err = io.Copy(part, input)
if err != nil {
return errors.Wrap(err, "failed to copy file to multipart writer form file")
return fmt.Errorf("failed to copy file to multipart writer form file: %w", err)
}
err = writer.Close()
if err != nil {
return errors.Wrap(err, "failed to close multipart writer")
return fmt.Errorf("failed to close multipart writer: %w", err)
}

path := "/v1/upload_to_stage"
url := c.makeURL(path)
req, err := http.NewRequest("PUT", url, body)
if err != nil {
return errors.Wrap(err, "failed to create http request")
return fmt.Errorf("failed to create http request: %w", err)
}

req.Header, err = c.makeHeaders(ctx)
if err != nil {
return errors.Wrap(err, "failed to make headers")
return fmt.Errorf("failed to make headers: %w", err)
}
if len(c.host) > 0 {
req.Host = c.host
Expand All @@ -723,15 +729,15 @@ func (c *APIClient) UploadToStageByAPI(ctx context.Context, stage *StageLocation
}
resp, err := httpClient.Do(req)
if err != nil {
return errors.Wrap(err, "failed http do request")
return fmt.Errorf("failed http do request: %w", err)
}
defer func() {
_ = resp.Body.Close()
}()

respBody, err := io.ReadAll(resp.Body)
if err != nil {
return errors.Wrap(err, "failed to read http response body")
return fmt.Errorf("failed to read http response body: %w", err)
}

if resp.StatusCode == http.StatusUnauthorized {
Expand Down
Loading