@@ -2,11 +2,12 @@ package nvidia_inferenceserver
2
2
3
3
import (
4
4
"context"
5
+ "encoding/json"
5
6
"errors"
6
7
"strconv"
7
8
"time"
8
9
9
- "github.com/goccy/go-json "
10
+ "github.com/sunhailin-Leo/triton-service-go/utils "
10
11
"github.com/valyala/fasthttp"
11
12
"google.golang.org/grpc"
12
13
)
@@ -122,6 +123,12 @@ type TritonClientService struct {
122
123
grpcConn * grpc.ClientConn
123
124
grpcClient GRPCInferenceServiceClient
124
125
httpClient * fasthttp.Client
126
+
127
+ // Default: json.Marshal
128
+ JSONEncoder utils.JSONMarshal
129
+
130
+ // Default: json.Unmarshal
131
+ JSONDecoder utils.JSONUnmarshal
125
132
}
126
133
127
134
// disconnectToTritonWithGRPC Disconnect GRPC Connection.
@@ -247,6 +254,28 @@ func (t *TritonClientService) decodeFuncErrorHandler(err error, isGRPC bool) err
247
254
248
255
///////////////////////////////////////////// expose API below /////////////////////////////////////////////
249
256
257
+ // JsonMarshal Json Encoder
258
+ func (t * TritonClientService ) JsonMarshal (v interface {}) ([]byte , error ) {
259
+ return t .JSONEncoder (v )
260
+ }
261
+
262
+ // JsonUnmarshal Json Decoder
263
+ func (t * TritonClientService ) JsonUnmarshal (data []byte , v interface {}) error {
264
+ return t .JSONDecoder (data , v )
265
+ }
266
+
267
+ // SetJSONEncoder set json encoder
268
+ func (t * TritonClientService ) SetJSONEncoder (encoder utils.JSONMarshal ) * TritonClientService {
269
+ t .JSONEncoder = encoder
270
+ return t
271
+ }
272
+
273
+ // SetJsonDecoder set json decoder
274
+ func (t * TritonClientService ) SetJsonDecoder (decoder utils.JSONUnmarshal ) * TritonClientService {
275
+ t .JSONDecoder = decoder
276
+ return t
277
+ }
278
+
250
279
// ModelHTTPInfer Call Triton Infer with HTTP.
251
280
func (t * TritonClientService ) ModelHTTPInfer (
252
281
requestBody []byte ,
@@ -379,7 +408,7 @@ func (t *TritonClientService) ServerMetadata(timeout time.Duration) (*ServerMeta
379
408
return nil , t .httpErrorHandler (apiResp .StatusCode (), httpErr )
380
409
}
381
410
serverMetadataResponse := new (ServerMetadataResponse )
382
- if jsonDecodeErr := json . Unmarshal (apiResp .Body (), & serverMetadataResponse ); jsonDecodeErr != nil {
411
+ if jsonDecodeErr := t . JSONDecoder (apiResp .Body (), & serverMetadataResponse ); jsonDecodeErr != nil {
383
412
return nil , jsonDecodeErr
384
413
}
385
414
return serverMetadataResponse , nil
@@ -406,7 +435,7 @@ func (t *TritonClientService) ModelMetadataRequest(
406
435
return nil , t .httpErrorHandler (apiResp .StatusCode (), httpErr )
407
436
}
408
437
modelMetadataResponse := new (ModelMetadataResponse )
409
- if jsonDecodeErr := json . Unmarshal (apiResp .Body (), & modelMetadataResponse ); jsonDecodeErr != nil {
438
+ if jsonDecodeErr := t . JSONDecoder (apiResp .Body (), & modelMetadataResponse ); jsonDecodeErr != nil {
410
439
return nil , jsonDecodeErr
411
440
}
412
441
return modelMetadataResponse , nil
@@ -425,7 +454,7 @@ func (t *TritonClientService) ModelIndex(
425
454
ctx , & RepositoryIndexRequest {RepositoryName : repoName , Ready : isReady })
426
455
return repositoryIndexResponse , t .grpcErrorHandler (modelIndexErr )
427
456
}
428
- reqBody , jsonEncodeErr := json . Marshal (& ModelIndexRequestHTTPObj {repoName , isReady })
457
+ reqBody , jsonEncodeErr := t . JsonMarshal (& ModelIndexRequestHTTPObj {repoName , isReady })
429
458
if jsonEncodeErr != nil {
430
459
return nil , jsonEncodeErr
431
460
}
@@ -435,7 +464,7 @@ func (t *TritonClientService) ModelIndex(
435
464
return nil , t .httpErrorHandler (apiResp .StatusCode (), httpErr )
436
465
}
437
466
repositoryIndexResponse := new (RepositoryIndexResponse )
438
- if jsonDecodeErr := json . Unmarshal (apiResp .Body (), & repositoryIndexResponse .Models ); jsonDecodeErr != nil {
467
+ if jsonDecodeErr := t . JSONDecoder (apiResp .Body (), & repositoryIndexResponse .Models ); jsonDecodeErr != nil {
439
468
return nil , jsonDecodeErr
440
469
}
441
470
return repositoryIndexResponse , nil
@@ -461,7 +490,7 @@ func (t *TritonClientService) ModelConfiguration(
461
490
return nil , t .httpErrorHandler (apiResp .StatusCode (), httpErr )
462
491
}
463
492
modelConfigResponse := new (ModelConfigResponse )
464
- if jsonDecodeErr := json . Unmarshal (apiResp .Body (), & modelConfigResponse ); jsonDecodeErr != nil {
493
+ if jsonDecodeErr := t . JSONDecoder (apiResp .Body (), & modelConfigResponse ); jsonDecodeErr != nil {
465
494
return nil , jsonDecodeErr
466
495
}
467
496
return modelConfigResponse , nil
@@ -487,7 +516,7 @@ func (t *TritonClientService) ModelInferStats(
487
516
return nil , t .httpErrorHandler (apiResp .StatusCode (), httpErr )
488
517
}
489
518
modelStatisticsResponse := new (ModelStatisticsResponse )
490
- jsonDecodeErr := json . Unmarshal (apiResp .Body (), & modelStatisticsResponse )
519
+ jsonDecodeErr := t . JSONDecoder (apiResp .Body (), & modelStatisticsResponse )
491
520
if jsonDecodeErr != nil {
492
521
return nil , jsonDecodeErr
493
522
}
@@ -507,7 +536,7 @@ func (t *TritonClientService) ModelLoadWithHTTP(
507
536
return nil , t .httpErrorHandler (apiResp .StatusCode (), httpErr )
508
537
}
509
538
repositoryModelLoadResponse := new (RepositoryModelLoadResponse )
510
- if jsonDecodeErr := json . Unmarshal (apiResp .Body (), & repositoryModelLoadResponse ); jsonDecodeErr != nil {
539
+ if jsonDecodeErr := t . JSONDecoder (apiResp .Body (), & repositoryModelLoadResponse ); jsonDecodeErr != nil {
511
540
return nil , jsonDecodeErr
512
541
}
513
542
return repositoryModelLoadResponse , nil
@@ -539,7 +568,7 @@ func (t *TritonClientService) ModelUnloadWithHTTP(
539
568
return nil , t .httpErrorHandler (apiResp .StatusCode (), httpErr )
540
569
}
541
570
repositoryModelUnloadResponse := new (RepositoryModelUnloadResponse )
542
- jsonDecodeErr := json . Unmarshal (apiResp .Body (), & repositoryModelUnloadResponse )
571
+ jsonDecodeErr := t . JSONDecoder (apiResp .Body (), & repositoryModelUnloadResponse )
543
572
if jsonDecodeErr != nil {
544
573
return nil , jsonDecodeErr
545
574
}
@@ -601,13 +630,13 @@ func (t *TritonClientService) ShareMemoryStatus(
601
630
// Parse Response
602
631
if isCUDA {
603
632
cudaSharedMemoryStatusResponse := new (CudaSharedMemoryStatusResponse )
604
- if jsonDecodeErr := json . Unmarshal (apiResp .Body (), & cudaSharedMemoryStatusResponse ); jsonDecodeErr != nil {
633
+ if jsonDecodeErr := t . JSONDecoder (apiResp .Body (), & cudaSharedMemoryStatusResponse ); jsonDecodeErr != nil {
605
634
return nil , jsonDecodeErr
606
635
}
607
636
return cudaSharedMemoryStatusResponse , nil
608
637
}
609
638
systemSharedMemoryStatusResponse := new (SystemSharedMemoryStatusResponse )
610
- if jsonDecodeErr := json . Unmarshal (apiResp .Body (), & systemSharedMemoryStatusResponse ); jsonDecodeErr != nil {
639
+ if jsonDecodeErr := t . JSONDecoder (apiResp .Body (), & systemSharedMemoryStatusResponse ); jsonDecodeErr != nil {
611
640
return nil , jsonDecodeErr
612
641
}
613
642
return systemSharedMemoryStatusResponse , nil
@@ -632,7 +661,7 @@ func (t *TritonClientService) ShareCUDAMemoryRegister(
632
661
)
633
662
return cudaSharedMemoryRegisterResponse , t .grpcErrorHandler (registerErr )
634
663
}
635
- reqBody , jsonEncodeErr := json . Marshal (
664
+ reqBody , jsonEncodeErr := t . JsonMarshal (
636
665
& CudaMemoryRegisterBodyHTTPObj {cudaRawHandle , cudaDeviceID , byteSize })
637
666
if jsonEncodeErr != nil {
638
667
return nil , jsonEncodeErr
@@ -644,7 +673,7 @@ func (t *TritonClientService) ShareCUDAMemoryRegister(
644
673
return nil , t .httpErrorHandler (apiResp .StatusCode (), httpErr )
645
674
}
646
675
cudaSharedMemoryRegisterResponse := new (CudaSharedMemoryRegisterResponse )
647
- if jsonDecodeErr := json . Unmarshal (apiResp .Body (), & cudaSharedMemoryRegisterResponse ); jsonDecodeErr != nil {
676
+ if jsonDecodeErr := t . JSONDecoder (apiResp .Body (), & cudaSharedMemoryRegisterResponse ); jsonDecodeErr != nil {
648
677
return nil , jsonDecodeErr
649
678
}
650
679
return cudaSharedMemoryRegisterResponse , nil
@@ -670,7 +699,7 @@ func (t *TritonClientService) ShareCUDAMemoryUnRegister(
670
699
return nil , t .httpErrorHandler (apiResp .StatusCode (), httpErr )
671
700
}
672
701
cudaSharedMemoryUnregisterResponse := new (CudaSharedMemoryUnregisterResponse )
673
- if jsonDecodeErr := json . Unmarshal (apiResp .Body (), & cudaSharedMemoryUnregisterResponse ); jsonDecodeErr != nil {
702
+ if jsonDecodeErr := t . JSONDecoder (apiResp .Body (), & cudaSharedMemoryUnregisterResponse ); jsonDecodeErr != nil {
674
703
return nil , jsonDecodeErr
675
704
}
676
705
return cudaSharedMemoryUnregisterResponse , nil
@@ -695,7 +724,7 @@ func (t *TritonClientService) ShareSystemMemoryRegister(
695
724
)
696
725
return systemSharedMemoryRegisterResponse , t .grpcErrorHandler (registerErr )
697
726
}
698
- reqBody , jsonEncodeErr := json . Marshal (
727
+ reqBody , jsonEncodeErr := t . JsonMarshal (
699
728
& SystemMemoryRegisterBodyHTTPObj {cpuMemRegionKey , cpuMemOffset , byteSize })
700
729
if jsonEncodeErr != nil {
701
730
return nil , jsonEncodeErr
@@ -707,7 +736,7 @@ func (t *TritonClientService) ShareSystemMemoryRegister(
707
736
return nil , t .httpErrorHandler (apiResp .StatusCode (), httpErr )
708
737
}
709
738
systemSharedMemoryRegisterResponse := new (SystemSharedMemoryRegisterResponse )
710
- if jsonDecodeErr := json . Unmarshal (apiResp .Body (), & systemSharedMemoryRegisterResponse ); jsonDecodeErr != nil {
739
+ if jsonDecodeErr := t . JSONDecoder (apiResp .Body (), & systemSharedMemoryRegisterResponse ); jsonDecodeErr != nil {
711
740
return nil , jsonDecodeErr
712
741
}
713
742
return systemSharedMemoryRegisterResponse , nil
@@ -733,7 +762,7 @@ func (t *TritonClientService) ShareSystemMemoryUnRegister(
733
762
return nil , t .httpErrorHandler (apiResp .StatusCode (), httpErr )
734
763
}
735
764
systemSharedMemoryUnregisterResponse := new (SystemSharedMemoryUnregisterResponse )
736
- if jsonDecodeErr := json . Unmarshal (apiResp .Body (), & systemSharedMemoryUnregisterResponse ); jsonDecodeErr != nil {
765
+ if jsonDecodeErr := t . JSONDecoder (apiResp .Body (), & systemSharedMemoryUnregisterResponse ); jsonDecodeErr != nil {
737
766
return nil , jsonDecodeErr
738
767
}
739
768
return systemSharedMemoryUnregisterResponse , nil
@@ -759,7 +788,7 @@ func (t *TritonClientService) GetModelTracingSetting(
759
788
return nil , t .httpErrorHandler (apiResp .StatusCode (), httpErr )
760
789
}
761
790
traceSettingResponse := new (TraceSettingResponse )
762
- if jsonDecodeErr := json . Unmarshal (apiResp .Body (), traceSettingResponse ); jsonDecodeErr != nil {
791
+ if jsonDecodeErr := t . JSONDecoder (apiResp .Body (), traceSettingResponse ); jsonDecodeErr != nil {
763
792
return nil , jsonDecodeErr
764
793
}
765
794
return traceSettingResponse , nil
@@ -780,7 +809,7 @@ func (t *TritonClientService) SetModelTracingSetting(
780
809
return traceSettingResponse , t .grpcErrorHandler (setTraceSettingErr )
781
810
}
782
811
// Experimental
783
- reqBody , jsonEncodeErr := json . Marshal (& TraceSettingRequestHTTPObj {settingMap })
812
+ reqBody , jsonEncodeErr := t . JsonMarshal (& TraceSettingRequestHTTPObj {settingMap })
784
813
if jsonEncodeErr != nil {
785
814
return nil , jsonEncodeErr
786
815
}
@@ -791,7 +820,7 @@ func (t *TritonClientService) SetModelTracingSetting(
791
820
return nil , t .httpErrorHandler (apiResp .StatusCode (), httpErr )
792
821
}
793
822
traceSettingResponse := new (TraceSettingResponse )
794
- if jsonDecodeErr := json . Unmarshal (apiResp .Body (), traceSettingResponse ); jsonDecodeErr != nil {
823
+ if jsonDecodeErr := t . JSONDecoder (apiResp .Body (), traceSettingResponse ); jsonDecodeErr != nil {
795
824
return nil , jsonDecodeErr
796
825
}
797
826
return traceSettingResponse , nil
@@ -818,7 +847,7 @@ func (t *TritonClientService) ShutdownTritonConnection() (disconnectionErr error
818
847
819
848
// NewTritonClientWithOnlyHTTP init triton client.
820
849
func NewTritonClientWithOnlyHTTP (uri string , httpClient * fasthttp.Client ) * TritonClientService {
821
- client := & TritonClientService {serverURL : uri }
850
+ client := & TritonClientService {serverURL : uri , JSONEncoder : json . Marshal , JSONDecoder : json . Unmarshal }
822
851
client .setHTTPConnection (httpClient )
823
852
return client
824
853
}
@@ -828,7 +857,12 @@ func NewTritonClientWithOnlyGRPC(grpcConn *grpc.ClientConn) *TritonClientService
828
857
if grpcConn == nil {
829
858
return nil
830
859
}
831
- client := & TritonClientService {grpcConn : grpcConn , grpcClient : NewGRPCInferenceServiceClient (grpcConn )}
860
+ client := & TritonClientService {
861
+ grpcConn : grpcConn ,
862
+ grpcClient : NewGRPCInferenceServiceClient (grpcConn ),
863
+ JSONEncoder : json .Marshal ,
864
+ JSONDecoder : json .Unmarshal ,
865
+ }
832
866
return client
833
867
}
834
868
@@ -837,9 +871,11 @@ func NewTritonClientForAll(
837
871
httpServerURL string , httpClient * fasthttp.Client , grpcConn * grpc.ClientConn ,
838
872
) * TritonClientService {
839
873
client := & TritonClientService {
840
- serverURL : httpServerURL ,
841
- grpcConn : grpcConn ,
842
- grpcClient : NewGRPCInferenceServiceClient (grpcConn ),
874
+ serverURL : httpServerURL ,
875
+ grpcConn : grpcConn ,
876
+ grpcClient : NewGRPCInferenceServiceClient (grpcConn ),
877
+ JSONEncoder : json .Marshal ,
878
+ JSONDecoder : json .Unmarshal ,
843
879
}
844
880
client .setHTTPConnection (httpClient )
845
881
0 commit comments