Skip to content

Commit cc518ce

Browse files
committed
version 1.4.6
1 parent e3d13a2 commit cc518ce

File tree

9 files changed

+92
-50
lines changed

9 files changed

+92
-50
lines changed

README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,10 @@ func main() {
144144

145145
### Version
146146

147+
* version 1.4.6 - 2023/07/27
148+
* remove `github.com/goccy/go-json` and make `encoding/json` the default JSON marshaler/unmarshaler.
149+
* add `JsonEncoder` and `JsonDecoder` APIs so that other JSON parsers can be plugged in.
150+
147151
* version 1.4.5 - 2023/07/12
148152
* update go.mod
149153
* fix Chinese tokenizer error

go.mod

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ module github.com/sunhailin-Leo/triton-service-go
33
go 1.18
44

55
require (
6-
github.com/goccy/go-json v0.10.2
76
github.com/valyala/fasthttp v1.48.0
87
golang.org/x/text v0.11.0
98
google.golang.org/grpc v1.56.2

go.sum

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
github.com/andybalholm/brotli v1.0.5 h1:8uQZIdzKmjc/iuPu7O2ioW48L81FgatrcpfFmiq/cCs=
22
github.com/andybalholm/brotli v1.0.5/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig=
3-
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
4-
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
53
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
64
github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg=
75
github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=

models/bert/model.go

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,10 @@ import (
55
"strings"
66
"time"
77

8-
"github.com/goccy/go-json"
9-
"github.com/valyala/fasthttp"
10-
"google.golang.org/grpc"
11-
128
"github.com/sunhailin-Leo/triton-service-go/nvidia_inferenceserver"
139
"github.com/sunhailin-Leo/triton-service-go/utils"
10+
"github.com/valyala/fasthttp"
11+
"google.golang.org/grpc"
1412
)
1513

1614
const (
@@ -25,7 +23,6 @@ const (
2523
)
2624

2725
type ModelService struct {
28-
// isTraceDuration bool
2926
isGRPC bool
3027
isChinese bool
3128
isChineseCharMode bool
@@ -42,18 +39,6 @@ type ModelService struct {
4239

4340
////////////////////////////////////////////////// Flag Switch API //////////////////////////////////////////////////
4441

45-
//// SetModelInferWithTrace Set model infer trace obj.
46-
// func (m *ModelService) SetModelInferWithTrace() *ModelService {
47-
// m.isTraceDuration = true
48-
// return m
49-
// }
50-
//
51-
//// UnsetModelInferWithTrace unset model infer trace obj.
52-
// func (m *ModelService) UnsetModelInferWithTrace() *ModelService {
53-
// m.isTraceDuration = false
54-
// return m
55-
// }
56-
5742
// SetMaxSeqLength Set model infer max sequence length.
5843
func (m *ModelService) SetMaxSeqLength(maxSeqLen int) *ModelService {
5944
m.maxSeqLength = maxSeqLen
@@ -134,6 +119,18 @@ func (m *ModelService) SetSecondaryServerURL(url string) *ModelService {
134119
return m
135120
}
136121

122+
// SetJsonEncoder sets the JSON encoder on the underlying Triton client service.
123+
func (m *ModelService) SetJsonEncoder(encoder utils.JSONMarshal) *ModelService {
124+
m.tritonService.SetJSONEncoder(encoder)
125+
return m
126+
}
127+
128+
// SetJsonDecoder sets the JSON decoder on the underlying Triton client service.
129+
func (m *ModelService) SetJsonDecoder(decoder utils.JSONUnmarshal) *ModelService {
130+
m.tritonService.SetJsonDecoder(decoder)
131+
return m
132+
}
133+
137134
////////////////////////////////////////////////// Flag Switch API //////////////////////////////////////////////////
138135

139136
///////////////////////////////////////// Bert Service Pre-Process Function /////////////////////////////////////////
@@ -265,7 +262,7 @@ func (m *ModelService) generateHTTPRequest(
265262
) ([]byte, []*InputObjects, error) {
266263
// Generate batch request json body
267264
requestInputBody, modelInputObj := m.generateHTTPInputs(inferDataArr, inferInputs)
268-
jsonBody, jsonEncodeErr := json.Marshal(&HTTPRequestBody{
265+
jsonBody, jsonEncodeErr := m.tritonService.JsonMarshal(&HTTPRequestBody{
269266
Inputs: requestInputBody,
270267
Outputs: m.generateHTTPOutputs(inferOutputs),
271268
})

models/bert/tokenizer.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ func NewBaseTokenizer(opts ...OptionV1) *BaseTokenizer {
9393
// The resulting tokens preserve the alignment with the portion of the original text they belong to.
9494
func (t *BaseTokenizer) Tokenize(text string) []StringOffsetsPair {
9595
splitTokens := make([]StringOffsetsPair, 0)
96+
text = utils.Clean(text)
9697
spaceTokens := t.splitOn(text, utils.IsWhitespace, false)
9798

9899
for i := range spaceTokens {

nvidia_inferenceserver/triton_service_interface.go

Lines changed: 61 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,12 @@ package nvidia_inferenceserver
22

33
import (
44
"context"
5+
"encoding/json"
56
"errors"
67
"strconv"
78
"time"
89

9-
"github.com/goccy/go-json"
10+
"github.com/sunhailin-Leo/triton-service-go/utils"
1011
"github.com/valyala/fasthttp"
1112
"google.golang.org/grpc"
1213
)
@@ -122,6 +123,12 @@ type TritonClientService struct {
122123
grpcConn *grpc.ClientConn
123124
grpcClient GRPCInferenceServiceClient
124125
httpClient *fasthttp.Client
126+
127+
// JSONEncoder marshals values to JSON. Default: json.Marshal
128+
JSONEncoder utils.JSONMarshal
129+
130+
// JSONDecoder unmarshals JSON data into values. Default: json.Unmarshal
131+
JSONDecoder utils.JSONUnmarshal
125132
}
126133

127134
// disconnectToTritonWithGRPC Disconnect GRPC Connection.
@@ -247,6 +254,28 @@ func (t *TritonClientService) decodeFuncErrorHandler(err error, isGRPC bool) err
247254

248255
///////////////////////////////////////////// expose API below /////////////////////////////////////////////
249256

257+
// JsonMarshal encodes v to JSON using the configured JSONEncoder.
258+
func (t *TritonClientService) JsonMarshal(v interface{}) ([]byte, error) {
259+
return t.JSONEncoder(v)
260+
}
261+
262+
// JsonUnmarshal decodes JSON data into v using the configured JSONDecoder.
263+
func (t *TritonClientService) JsonUnmarshal(data []byte, v interface{}) error {
264+
return t.JSONDecoder(data, v)
265+
}
266+
267+
// SetJSONEncoder sets the JSON encoder and returns the client so calls can be chained.
268+
func (t *TritonClientService) SetJSONEncoder(encoder utils.JSONMarshal) *TritonClientService {
269+
t.JSONEncoder = encoder
270+
return t
271+
}
272+
273+
// SetJsonDecoder sets the JSON decoder and returns the client so calls can be chained.
274+
func (t *TritonClientService) SetJsonDecoder(decoder utils.JSONUnmarshal) *TritonClientService {
275+
t.JSONDecoder = decoder
276+
return t
277+
}
278+
250279
// ModelHTTPInfer Call Triton Infer with HTTP.
251280
func (t *TritonClientService) ModelHTTPInfer(
252281
requestBody []byte,
@@ -379,7 +408,7 @@ func (t *TritonClientService) ServerMetadata(timeout time.Duration) (*ServerMeta
379408
return nil, t.httpErrorHandler(apiResp.StatusCode(), httpErr)
380409
}
381410
serverMetadataResponse := new(ServerMetadataResponse)
382-
if jsonDecodeErr := json.Unmarshal(apiResp.Body(), &serverMetadataResponse); jsonDecodeErr != nil {
411+
if jsonDecodeErr := t.JSONDecoder(apiResp.Body(), &serverMetadataResponse); jsonDecodeErr != nil {
383412
return nil, jsonDecodeErr
384413
}
385414
return serverMetadataResponse, nil
@@ -406,7 +435,7 @@ func (t *TritonClientService) ModelMetadataRequest(
406435
return nil, t.httpErrorHandler(apiResp.StatusCode(), httpErr)
407436
}
408437
modelMetadataResponse := new(ModelMetadataResponse)
409-
if jsonDecodeErr := json.Unmarshal(apiResp.Body(), &modelMetadataResponse); jsonDecodeErr != nil {
438+
if jsonDecodeErr := t.JSONDecoder(apiResp.Body(), &modelMetadataResponse); jsonDecodeErr != nil {
410439
return nil, jsonDecodeErr
411440
}
412441
return modelMetadataResponse, nil
@@ -425,7 +454,7 @@ func (t *TritonClientService) ModelIndex(
425454
ctx, &RepositoryIndexRequest{RepositoryName: repoName, Ready: isReady})
426455
return repositoryIndexResponse, t.grpcErrorHandler(modelIndexErr)
427456
}
428-
reqBody, jsonEncodeErr := json.Marshal(&ModelIndexRequestHTTPObj{repoName, isReady})
457+
reqBody, jsonEncodeErr := t.JsonMarshal(&ModelIndexRequestHTTPObj{repoName, isReady})
429458
if jsonEncodeErr != nil {
430459
return nil, jsonEncodeErr
431460
}
@@ -435,7 +464,7 @@ func (t *TritonClientService) ModelIndex(
435464
return nil, t.httpErrorHandler(apiResp.StatusCode(), httpErr)
436465
}
437466
repositoryIndexResponse := new(RepositoryIndexResponse)
438-
if jsonDecodeErr := json.Unmarshal(apiResp.Body(), &repositoryIndexResponse.Models); jsonDecodeErr != nil {
467+
if jsonDecodeErr := t.JSONDecoder(apiResp.Body(), &repositoryIndexResponse.Models); jsonDecodeErr != nil {
439468
return nil, jsonDecodeErr
440469
}
441470
return repositoryIndexResponse, nil
@@ -461,7 +490,7 @@ func (t *TritonClientService) ModelConfiguration(
461490
return nil, t.httpErrorHandler(apiResp.StatusCode(), httpErr)
462491
}
463492
modelConfigResponse := new(ModelConfigResponse)
464-
if jsonDecodeErr := json.Unmarshal(apiResp.Body(), &modelConfigResponse); jsonDecodeErr != nil {
493+
if jsonDecodeErr := t.JSONDecoder(apiResp.Body(), &modelConfigResponse); jsonDecodeErr != nil {
465494
return nil, jsonDecodeErr
466495
}
467496
return modelConfigResponse, nil
@@ -487,7 +516,7 @@ func (t *TritonClientService) ModelInferStats(
487516
return nil, t.httpErrorHandler(apiResp.StatusCode(), httpErr)
488517
}
489518
modelStatisticsResponse := new(ModelStatisticsResponse)
490-
jsonDecodeErr := json.Unmarshal(apiResp.Body(), &modelStatisticsResponse)
519+
jsonDecodeErr := t.JSONDecoder(apiResp.Body(), &modelStatisticsResponse)
491520
if jsonDecodeErr != nil {
492521
return nil, jsonDecodeErr
493522
}
@@ -507,7 +536,7 @@ func (t *TritonClientService) ModelLoadWithHTTP(
507536
return nil, t.httpErrorHandler(apiResp.StatusCode(), httpErr)
508537
}
509538
repositoryModelLoadResponse := new(RepositoryModelLoadResponse)
510-
if jsonDecodeErr := json.Unmarshal(apiResp.Body(), &repositoryModelLoadResponse); jsonDecodeErr != nil {
539+
if jsonDecodeErr := t.JSONDecoder(apiResp.Body(), &repositoryModelLoadResponse); jsonDecodeErr != nil {
511540
return nil, jsonDecodeErr
512541
}
513542
return repositoryModelLoadResponse, nil
@@ -539,7 +568,7 @@ func (t *TritonClientService) ModelUnloadWithHTTP(
539568
return nil, t.httpErrorHandler(apiResp.StatusCode(), httpErr)
540569
}
541570
repositoryModelUnloadResponse := new(RepositoryModelUnloadResponse)
542-
jsonDecodeErr := json.Unmarshal(apiResp.Body(), &repositoryModelUnloadResponse)
571+
jsonDecodeErr := t.JSONDecoder(apiResp.Body(), &repositoryModelUnloadResponse)
543572
if jsonDecodeErr != nil {
544573
return nil, jsonDecodeErr
545574
}
@@ -601,13 +630,13 @@ func (t *TritonClientService) ShareMemoryStatus(
601630
// Parse Response
602631
if isCUDA {
603632
cudaSharedMemoryStatusResponse := new(CudaSharedMemoryStatusResponse)
604-
if jsonDecodeErr := json.Unmarshal(apiResp.Body(), &cudaSharedMemoryStatusResponse); jsonDecodeErr != nil {
633+
if jsonDecodeErr := t.JSONDecoder(apiResp.Body(), &cudaSharedMemoryStatusResponse); jsonDecodeErr != nil {
605634
return nil, jsonDecodeErr
606635
}
607636
return cudaSharedMemoryStatusResponse, nil
608637
}
609638
systemSharedMemoryStatusResponse := new(SystemSharedMemoryStatusResponse)
610-
if jsonDecodeErr := json.Unmarshal(apiResp.Body(), &systemSharedMemoryStatusResponse); jsonDecodeErr != nil {
639+
if jsonDecodeErr := t.JSONDecoder(apiResp.Body(), &systemSharedMemoryStatusResponse); jsonDecodeErr != nil {
611640
return nil, jsonDecodeErr
612641
}
613642
return systemSharedMemoryStatusResponse, nil
@@ -632,7 +661,7 @@ func (t *TritonClientService) ShareCUDAMemoryRegister(
632661
)
633662
return cudaSharedMemoryRegisterResponse, t.grpcErrorHandler(registerErr)
634663
}
635-
reqBody, jsonEncodeErr := json.Marshal(
664+
reqBody, jsonEncodeErr := t.JsonMarshal(
636665
&CudaMemoryRegisterBodyHTTPObj{cudaRawHandle, cudaDeviceID, byteSize})
637666
if jsonEncodeErr != nil {
638667
return nil, jsonEncodeErr
@@ -644,7 +673,7 @@ func (t *TritonClientService) ShareCUDAMemoryRegister(
644673
return nil, t.httpErrorHandler(apiResp.StatusCode(), httpErr)
645674
}
646675
cudaSharedMemoryRegisterResponse := new(CudaSharedMemoryRegisterResponse)
647-
if jsonDecodeErr := json.Unmarshal(apiResp.Body(), &cudaSharedMemoryRegisterResponse); jsonDecodeErr != nil {
676+
if jsonDecodeErr := t.JSONDecoder(apiResp.Body(), &cudaSharedMemoryRegisterResponse); jsonDecodeErr != nil {
648677
return nil, jsonDecodeErr
649678
}
650679
return cudaSharedMemoryRegisterResponse, nil
@@ -670,7 +699,7 @@ func (t *TritonClientService) ShareCUDAMemoryUnRegister(
670699
return nil, t.httpErrorHandler(apiResp.StatusCode(), httpErr)
671700
}
672701
cudaSharedMemoryUnregisterResponse := new(CudaSharedMemoryUnregisterResponse)
673-
if jsonDecodeErr := json.Unmarshal(apiResp.Body(), &cudaSharedMemoryUnregisterResponse); jsonDecodeErr != nil {
702+
if jsonDecodeErr := t.JSONDecoder(apiResp.Body(), &cudaSharedMemoryUnregisterResponse); jsonDecodeErr != nil {
674703
return nil, jsonDecodeErr
675704
}
676705
return cudaSharedMemoryUnregisterResponse, nil
@@ -695,7 +724,7 @@ func (t *TritonClientService) ShareSystemMemoryRegister(
695724
)
696725
return systemSharedMemoryRegisterResponse, t.grpcErrorHandler(registerErr)
697726
}
698-
reqBody, jsonEncodeErr := json.Marshal(
727+
reqBody, jsonEncodeErr := t.JsonMarshal(
699728
&SystemMemoryRegisterBodyHTTPObj{cpuMemRegionKey, cpuMemOffset, byteSize})
700729
if jsonEncodeErr != nil {
701730
return nil, jsonEncodeErr
@@ -707,7 +736,7 @@ func (t *TritonClientService) ShareSystemMemoryRegister(
707736
return nil, t.httpErrorHandler(apiResp.StatusCode(), httpErr)
708737
}
709738
systemSharedMemoryRegisterResponse := new(SystemSharedMemoryRegisterResponse)
710-
if jsonDecodeErr := json.Unmarshal(apiResp.Body(), &systemSharedMemoryRegisterResponse); jsonDecodeErr != nil {
739+
if jsonDecodeErr := t.JSONDecoder(apiResp.Body(), &systemSharedMemoryRegisterResponse); jsonDecodeErr != nil {
711740
return nil, jsonDecodeErr
712741
}
713742
return systemSharedMemoryRegisterResponse, nil
@@ -733,7 +762,7 @@ func (t *TritonClientService) ShareSystemMemoryUnRegister(
733762
return nil, t.httpErrorHandler(apiResp.StatusCode(), httpErr)
734763
}
735764
systemSharedMemoryUnregisterResponse := new(SystemSharedMemoryUnregisterResponse)
736-
if jsonDecodeErr := json.Unmarshal(apiResp.Body(), &systemSharedMemoryUnregisterResponse); jsonDecodeErr != nil {
765+
if jsonDecodeErr := t.JSONDecoder(apiResp.Body(), &systemSharedMemoryUnregisterResponse); jsonDecodeErr != nil {
737766
return nil, jsonDecodeErr
738767
}
739768
return systemSharedMemoryUnregisterResponse, nil
@@ -759,7 +788,7 @@ func (t *TritonClientService) GetModelTracingSetting(
759788
return nil, t.httpErrorHandler(apiResp.StatusCode(), httpErr)
760789
}
761790
traceSettingResponse := new(TraceSettingResponse)
762-
if jsonDecodeErr := json.Unmarshal(apiResp.Body(), traceSettingResponse); jsonDecodeErr != nil {
791+
if jsonDecodeErr := t.JSONDecoder(apiResp.Body(), traceSettingResponse); jsonDecodeErr != nil {
763792
return nil, jsonDecodeErr
764793
}
765794
return traceSettingResponse, nil
@@ -780,7 +809,7 @@ func (t *TritonClientService) SetModelTracingSetting(
780809
return traceSettingResponse, t.grpcErrorHandler(setTraceSettingErr)
781810
}
782811
// Experimental
783-
reqBody, jsonEncodeErr := json.Marshal(&TraceSettingRequestHTTPObj{settingMap})
812+
reqBody, jsonEncodeErr := t.JsonMarshal(&TraceSettingRequestHTTPObj{settingMap})
784813
if jsonEncodeErr != nil {
785814
return nil, jsonEncodeErr
786815
}
@@ -791,7 +820,7 @@ func (t *TritonClientService) SetModelTracingSetting(
791820
return nil, t.httpErrorHandler(apiResp.StatusCode(), httpErr)
792821
}
793822
traceSettingResponse := new(TraceSettingResponse)
794-
if jsonDecodeErr := json.Unmarshal(apiResp.Body(), traceSettingResponse); jsonDecodeErr != nil {
823+
if jsonDecodeErr := t.JSONDecoder(apiResp.Body(), traceSettingResponse); jsonDecodeErr != nil {
795824
return nil, jsonDecodeErr
796825
}
797826
return traceSettingResponse, nil
@@ -818,7 +847,7 @@ func (t *TritonClientService) ShutdownTritonConnection() (disconnectionErr error
818847

819848
// NewTritonClientWithOnlyHTTP init triton client.
820849
func NewTritonClientWithOnlyHTTP(uri string, httpClient *fasthttp.Client) *TritonClientService {
821-
client := &TritonClientService{serverURL: uri}
850+
client := &TritonClientService{serverURL: uri, JSONEncoder: json.Marshal, JSONDecoder: json.Unmarshal}
822851
client.setHTTPConnection(httpClient)
823852
return client
824853
}
@@ -828,7 +857,12 @@ func NewTritonClientWithOnlyGRPC(grpcConn *grpc.ClientConn) *TritonClientService
828857
if grpcConn == nil {
829858
return nil
830859
}
831-
client := &TritonClientService{grpcConn: grpcConn, grpcClient: NewGRPCInferenceServiceClient(grpcConn)}
860+
client := &TritonClientService{
861+
grpcConn: grpcConn,
862+
grpcClient: NewGRPCInferenceServiceClient(grpcConn),
863+
JSONEncoder: json.Marshal,
864+
JSONDecoder: json.Unmarshal,
865+
}
832866
return client
833867
}
834868

@@ -837,9 +871,11 @@ func NewTritonClientForAll(
837871
httpServerURL string, httpClient *fasthttp.Client, grpcConn *grpc.ClientConn,
838872
) *TritonClientService {
839873
client := &TritonClientService{
840-
serverURL: httpServerURL,
841-
grpcConn: grpcConn,
842-
grpcClient: NewGRPCInferenceServiceClient(grpcConn),
874+
serverURL: httpServerURL,
875+
grpcConn: grpcConn,
876+
grpcClient: NewGRPCInferenceServiceClient(grpcConn),
877+
JSONEncoder: json.Marshal,
878+
JSONDecoder: json.Unmarshal,
843879
}
844880
client.setHTTPConnection(httpClient)
845881

test/bert_test.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,10 @@ import (
55
"testing"
66

77
"github.com/sunhailin-Leo/triton-service-go/models/bert"
8+
"github.com/sunhailin-Leo/triton-service-go/nvidia_inferenceserver"
89
"github.com/valyala/fasthttp"
910
"google.golang.org/grpc"
1011
"google.golang.org/grpc/credentials/insecure"
11-
12-
"github.com/sunhailin-Leo/triton-service-go/nvidia_inferenceserver"
1312
)
1413

1514
const (

test/triton_test.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,10 @@ import (
44
"errors"
55
"testing"
66

7+
"github.com/sunhailin-Leo/triton-service-go/nvidia_inferenceserver"
78
"github.com/valyala/fasthttp"
89
"google.golang.org/grpc"
910
"google.golang.org/grpc/credentials/insecure"
10-
11-
"github.com/sunhailin-Leo/triton-service-go/nvidia_inferenceserver"
1211
)
1312

1413
func TestTritonHTTPClientInit(_ *testing.T) {

0 commit comments

Comments
 (0)