
Commit 59dc759

filintod and cicoyle authored
conversation api - refactor langchaingo common models (#3846)
Signed-off-by: Filinto Duran <1373693+filintod@users.noreply.github.com>
Co-authored-by: Cassie Coyle <cassie@diagrid.io>
1 parent 6cc652c commit 59dc759

File tree

34 files changed: +879 -374 lines changed


.gitignore

Lines changed: 1 addition & 0 deletions
@@ -9,3 +9,4 @@ go.work.sum
 component-metadata-bundle.json
 *.log
 metadataanalyzer
+**/.env

conversation/anthropic/anthropic.go

Lines changed: 5 additions & 46 deletions
@@ -19,16 +19,16 @@ import (
 	"reflect"

 	"github.com/dapr/components-contrib/conversation"
+	"github.com/dapr/components-contrib/conversation/langchaingokit"
 	"github.com/dapr/components-contrib/metadata"
 	"github.com/dapr/kit/logger"
 	kmeta "github.com/dapr/kit/metadata"

-	"github.com/tmc/langchaingo/llms"
 	"github.com/tmc/langchaingo/llms/anthropic"
 )

 type Anthropic struct {
-	llm llms.Model
+	langchaingokit.LLM

 	logger logger.Logger
 }
@@ -63,15 +63,15 @@ func (a *Anthropic) Init(ctx context.Context, meta conversation.Metadata) error
 		return err
 	}

-	a.llm = llm
+	a.LLM.Model = llm

 	if m.CacheTTL != "" {
-		cachedModel, cacheErr := conversation.CacheModel(ctx, m.CacheTTL, a.llm)
+		cachedModel, cacheErr := conversation.CacheModel(ctx, m.CacheTTL, a.LLM.Model)
 		if cacheErr != nil {
 			return cacheErr
 		}

-		a.llm = cachedModel
+		a.LLM.Model = cachedModel
 	}

 	return nil
@@ -83,47 +83,6 @@ func (a *Anthropic) GetComponentMetadata() (metadataInfo metadata.MetadataMap) {
 	return
 }

-func (a *Anthropic) Converse(ctx context.Context, r *conversation.ConversationRequest) (res *conversation.ConversationResponse, err error) {
-	messages := make([]llms.MessageContent, 0, len(r.Inputs))
-
-	for _, input := range r.Inputs {
-		role := conversation.ConvertLangchainRole(input.Role)
-
-		messages = append(messages, llms.MessageContent{
-			Role: role,
-			Parts: []llms.ContentPart{
-				llms.TextPart(input.Message),
-			},
-		})
-	}
-
-	opts := []llms.CallOption{}
-
-	if r.Temperature > 0 {
-		opts = append(opts, conversation.LangchainTemperature(r.Temperature))
-	}
-
-	resp, err := a.llm.GenerateContent(ctx, messages, opts...)
-	if err != nil {
-		return nil, err
-	}
-
-	outputs := make([]conversation.ConversationResult, 0, len(resp.Choices))
-
-	for i := range resp.Choices {
-		outputs = append(outputs, conversation.ConversationResult{
-			Result:     resp.Choices[i].Content,
-			Parameters: r.Parameters,
-		})
-	}
-
-	res = &conversation.ConversationResponse{
-		Outputs: outputs,
-	}
-
-	return res, nil
-}
-
 func (a *Anthropic) Close() error {
 	return nil
 }
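
The Converse method removed here (and from the other providers below) moves into the new conversation/langchaingokit package that each component now imports and embeds. That package is not part of this excerpt, so the following is only a hedged sketch of what the shared helper could look like, on the assumption that it simply hoists the deleted logic; only the LLM type and its Model field are confirmed by the diff above.

```go
// Hedged sketch only: the real conversation/langchaingokit package is not
// shown in this excerpt. The assumption is that it wraps a langchaingo
// llms.Model and hoists the Converse logic removed from each component.
package langchaingokit

import (
	"context"

	"github.com/dapr/components-contrib/conversation"
	"github.com/tmc/langchaingo/llms"
)

// LLM is the shared helper each component embeds; components assign the
// concrete provider model to Model during Init.
type LLM struct {
	Model llms.Model
}

// Converse mirrors the per-component implementations deleted in this commit.
func (l *LLM) Converse(ctx context.Context, r *conversation.ConversationRequest) (*conversation.ConversationResponse, error) {
	messages := make([]llms.MessageContent, 0, len(r.Inputs))
	for _, input := range r.Inputs {
		role := conversation.ConvertLangchainRole(input.Role)
		messages = append(messages, llms.MessageContent{
			Role:  role,
			Parts: []llms.ContentPart{llms.TextPart(input.Message)},
		})
	}

	opts := []llms.CallOption{}
	if r.Temperature > 0 {
		opts = append(opts, conversation.LangchainTemperature(r.Temperature))
	}

	resp, err := l.Model.GenerateContent(ctx, messages, opts...)
	if err != nil {
		return nil, err
	}

	outputs := make([]conversation.ConversationResult, 0, len(resp.Choices))
	for i := range resp.Choices {
		outputs = append(outputs, conversation.ConversationResult{
			Result:     resp.Choices[i].Content,
			Parameters: r.Parameters,
		})
	}

	return &conversation.ConversationResponse{Outputs: outputs}, nil
}
```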

conversation/aws/bedrock/bedrock.go

Lines changed: 5 additions & 46 deletions
@@ -20,18 +20,18 @@ import (

 	awsAuth "github.com/dapr/components-contrib/common/authentication/aws"
 	"github.com/dapr/components-contrib/conversation"
+	"github.com/dapr/components-contrib/conversation/langchaingokit"
 	"github.com/dapr/components-contrib/metadata"
 	"github.com/dapr/kit/logger"
 	kmeta "github.com/dapr/kit/metadata"

 	"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
-	"github.com/tmc/langchaingo/llms"
 	"github.com/tmc/langchaingo/llms/bedrock"
 )

 type AWSBedrock struct {
 	model string
-	llm   llms.Model
+	langchaingokit.LLM

 	logger logger.Logger
 }
@@ -81,15 +81,15 @@ func (b *AWSBedrock) Init(ctx context.Context, meta conversation.Metadata) error
 		return err
 	}

-	b.llm = llm
+	b.LLM.Model = llm

 	if m.CacheTTL != "" {
-		cachedModel, cacheErr := conversation.CacheModel(ctx, m.CacheTTL, b.llm)
+		cachedModel, cacheErr := conversation.CacheModel(ctx, m.CacheTTL, b.LLM.Model)
 		if cacheErr != nil {
 			return cacheErr
 		}

-		b.llm = cachedModel
+		b.LLM.Model = cachedModel
 	}
 	return nil
 }
@@ -100,47 +100,6 @@ func (b *AWSBedrock) GetComponentMetadata() (metadataInfo metadata.MetadataMap)
 	return
 }

-func (b *AWSBedrock) Converse(ctx context.Context, r *conversation.ConversationRequest) (res *conversation.ConversationResponse, err error) {
-	messages := make([]llms.MessageContent, 0, len(r.Inputs))
-
-	for _, input := range r.Inputs {
-		role := conversation.ConvertLangchainRole(input.Role)
-
-		messages = append(messages, llms.MessageContent{
-			Role: role,
-			Parts: []llms.ContentPart{
-				llms.TextPart(input.Message),
-			},
-		})
-	}
-
-	opts := []llms.CallOption{}
-
-	if r.Temperature > 0 {
-		opts = append(opts, conversation.LangchainTemperature(r.Temperature))
-	}
-
-	resp, err := b.llm.GenerateContent(ctx, messages, opts...)
-	if err != nil {
-		return nil, err
-	}
-
-	outputs := make([]conversation.ConversationResult, 0, len(resp.Choices))
-
-	for i := range resp.Choices {
-		outputs = append(outputs, conversation.ConversationResult{
-			Result:     resp.Choices[i].Content,
-			Parameters: r.Parameters,
-		})
-	}
-
-	res = &conversation.ConversationResponse{
-		Outputs: outputs,
-	}
-
-	return res, nil
-}
-
 func (b *AWSBedrock) Close() error {
 	return nil
 }
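
The bedrock change has the same shape: the per-component Converse is dropped and the struct embeds langchaingokit.LLM instead, so Go's method promotion is what keeps AWSBedrock exposing a Converse method without defining one. Below is a minimal, self-contained illustration of that mechanism, using simplified stand-in types rather than the repo's own.

```go
// Stand-alone illustration of the Go mechanism the refactor relies on:
// embedding a struct promotes its fields and methods onto the outer type.
// Names here are simplified stand-ins, not the repo's actual types.
package main

import "fmt"

// llm plays the role of langchaingokit.LLM.
type llm struct {
	Model string
}

// Converse is defined once on the embedded helper.
func (l *llm) Converse(prompt string) string {
	return fmt.Sprintf("[%s] %s", l.Model, prompt)
}

// awsBedrock embeds llm, so Converse and Model are promoted onto it.
type awsBedrock struct {
	llm
}

func main() {
	b := awsBedrock{}
	b.Model = "example-model" // in the real components this is assigned during Init
	fmt.Println(b.Converse("hello"))
}
```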

conversation/googleai/googleai.go

Lines changed: 5 additions & 46 deletions
@@ -19,16 +19,16 @@ import (
 	"reflect"

 	"github.com/dapr/components-contrib/conversation"
+	"github.com/dapr/components-contrib/conversation/langchaingokit"
 	"github.com/dapr/components-contrib/metadata"
 	"github.com/dapr/kit/logger"
 	kmeta "github.com/dapr/kit/metadata"

-	"github.com/tmc/langchaingo/llms"
 	"github.com/tmc/langchaingo/llms/googleai"
 )

 type GoogleAI struct {
-	llm llms.Model
+	langchaingokit.LLM

 	logger logger.Logger
 }
@@ -67,15 +67,15 @@ func (g *GoogleAI) Init(ctx context.Context, meta conversation.Metadata) error {
 		return err
 	}

-	g.llm = llm
+	g.LLM.Model = llm

 	if md.CacheTTL != "" {
-		cachedModel, cacheErr := conversation.CacheModel(ctx, md.CacheTTL, g.llm)
+		cachedModel, cacheErr := conversation.CacheModel(ctx, md.CacheTTL, g.LLM.Model)
 		if cacheErr != nil {
 			return cacheErr
 		}

-		g.llm = cachedModel
+		g.LLM.Model = cachedModel
 	}
 	return nil
 }
@@ -86,47 +86,6 @@ func (g *GoogleAI) GetComponentMetadata() (metadataInfo metadata.MetadataMap) {
 	return
 }

-func (g *GoogleAI) Converse(ctx context.Context, r *conversation.ConversationRequest) (res *conversation.ConversationResponse, err error) {
-	messages := make([]llms.MessageContent, 0, len(r.Inputs))
-
-	for _, input := range r.Inputs {
-		role := conversation.ConvertLangchainRole(input.Role)
-
-		messages = append(messages, llms.MessageContent{
-			Role: role,
-			Parts: []llms.ContentPart{
-				llms.TextPart(input.Message),
-			},
-		})
-	}
-
-	opts := []llms.CallOption{}
-
-	if r.Temperature > 0 {
-		opts = append(opts, conversation.LangchainTemperature(r.Temperature))
-	}
-
-	resp, err := g.llm.GenerateContent(ctx, messages, opts...)
-	if err != nil {
-		return nil, err
-	}
-
-	outputs := make([]conversation.ConversationResult, 0, len(resp.Choices))
-
-	for i := range resp.Choices {
-		outputs = append(outputs, conversation.ConversationResult{
-			Result:     resp.Choices[i].Content,
-			Parameters: r.Parameters,
-		})
-	}
-
-	res = &conversation.ConversationResponse{
-		Outputs: outputs,
-	}
-
-	return res, nil
-}
-
 func (g *GoogleAI) Close() error {
 	return nil
 }
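
All three Init methods keep the same two-step pattern: assign the provider model to LLM.Model, then, when a cache TTL is configured in metadata, replace it with the wrapper returned by conversation.CacheModel. That wrapper's internals are not in this diff; the sketch below only illustrates the general decorator shape under that assumption, using simplified stand-in types rather than langchaingo's llms.Model.

```go
// Hypothetical illustration of the wrap-the-model pattern shared by the Init
// methods above: assign the concrete model, then optionally swap it for a
// TTL-based caching decorator. Not the repo's conversation.CacheModel code.
package main

import (
	"context"
	"fmt"
	"sync"
	"time"
)

// textModel is a simplified stand-in for langchaingo's llms.Model.
type textModel interface {
	Generate(ctx context.Context, prompt string) (string, error)
}

// cachedModel decorates another model and reuses answers until they expire.
type cachedModel struct {
	inner textModel
	ttl   time.Duration

	mu    sync.Mutex
	items map[string]entry
}

type entry struct {
	text    string
	expires time.Time
}

func cacheModel(inner textModel, ttl time.Duration) textModel {
	return &cachedModel{inner: inner, ttl: ttl, items: map[string]entry{}}
}

func (c *cachedModel) Generate(ctx context.Context, prompt string) (string, error) {
	c.mu.Lock()
	if e, ok := c.items[prompt]; ok && time.Now().Before(e.expires) {
		c.mu.Unlock()
		return e.text, nil
	}
	c.mu.Unlock()

	out, err := c.inner.Generate(ctx, prompt)
	if err != nil {
		return "", err
	}

	c.mu.Lock()
	c.items[prompt] = entry{text: out, expires: time.Now().Add(c.ttl)}
	c.mu.Unlock()
	return out, nil
}

// echoModel stands in for a real provider model.
type echoModel struct{}

func (echoModel) Generate(_ context.Context, prompt string) (string, error) {
	return "echo: " + prompt, nil
}

func main() {
	var model textModel = echoModel{}
	// Mirrors Init: wrap the model only when a cache TTL is configured.
	model = cacheModel(model, time.Minute)

	out, _ := model.Generate(context.Background(), "hello")
	fmt.Println(out)
}
```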
