Skip to content

Commit cce3ca8

Browse files
authored
Merge pull request #3 from pedMatias/feature/multi-model-support
[feature & bug fix] Update openai library and add support to gpt4
2 parents 9d5afe5 + be10b70 commit cce3ca8

6 files changed

Lines changed: 46 additions & 23 deletions

File tree

README.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,19 @@ plot = PlotAI(df)
8989
plot.make("make a scatter plot")
9090
```
9191

92+
By default the library will use '*gpt-3.5-turbo*'. You can use different OpenAI models:
93+
94+
```python
95+
# import PlotAI
96+
from plotai import PlotAI
97+
98+
# create PlotAI object, pass pandas DataFrame as an argument
99+
plot = PlotAI(df, model_version="gpt-4")
100+
101+
# make a plot, just tell what you want
102+
plot.make("make a scatter plot")
103+
```
104+
92105
## More examples
93106

94107
#### Analyze the GPD dataset

plotai/llm/openai.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22
import openai
33

44
from dotenv import load_dotenv
5+
56
load_dotenv()
67

78

8-
class ChatGPT():
9+
class ChatGPT:
910

1011
temperature = 0
1112
max_tokens = 1000
@@ -14,14 +15,15 @@ class ChatGPT():
1415
presence_penalty = 0.6
1516
model = "gpt-3.5-turbo"
1617

17-
18-
def __init__(self):
18+
def __init__(self, model: str):
1919
api_key = os.environ.get("OPENAI_API_KEY")
2020
if api_key is None:
21-
raise Exception("Please set OPENAI_API_KEY environment variable."
22-
"You can obtain API key from https://platform.openai.com/account/api-keys")
21+
raise Exception(
22+
"Please set OPENAI_API_KEY environment variable."
23+
"You can obtain API key from https://platform.openai.com/account/api-keys"
24+
)
2325
openai.api_key = api_key
24-
26+
self.model = model
2527

2628
@property
2729
def _default_params(self):
@@ -35,6 +37,7 @@ def _default_params(self):
3537
}
3638

3739
def chat(self, prompt):
40+
client = openai.OpenAI()
3841

3942
params = {
4043
**self._default_params,
@@ -45,5 +48,5 @@ def chat(self, prompt):
4548
}
4649
],
4750
}
48-
response = openai.ChatCompletion.create(**params)
49-
return response["choices"][0]["message"]["content"]
51+
response = client.chat.completions.create(**params)
52+
return response.choices[0].message.content

plotai/plotai.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,14 @@
55
from plotai.code.executor import Executor
66
from plotai.code.logger import Logger
77

8+
89
class PlotAI:
910

10-
def __init__(self, *args, **kwargs):
11+
def __init__(self, model_version: str = "gpt-3.5-turbo", *args, **kwargs):
1112

13+
# OpenAI Model Version
14+
self.model_version = model_version
15+
# DataFrame to plot
1216
self.df, self.x, self.y, self.z = None, None, None, None
1317
if len(args) > 1:
1418
for i in range(len(args)):
@@ -34,13 +38,11 @@ def __init__(self, *args, **kwargs):
3438
setattr(self, k, kwargs[k])
3539

3640
def make(self, prompt):
37-
38-
df, x, y, z = self.df, self.x, self.y, self.z
39-
p = Prompt(prompt, self.df, self.x, self.y, self.z)
41+
p = Prompt(prompt, self.df, self.x, self.y, self.z)
4042

4143
Logger().log({"title": "Prompt", "details": p.value})
4244

43-
response = ChatGPT().chat(p.value)
45+
response = ChatGPT(model=self.model_version).chat(p.value)
4446

4547
Logger().log({"title": "Response", "details": response})
4648

@@ -49,8 +51,7 @@ def make(self, prompt):
4951
if error is not None:
5052
Logger().log({"title": "Error in code execution", "details": error})
5153

52-
53-
# p_again = Prompt(prompt, self.df, self.x, self.y, self.z, previous_code=response, previous_error=error)
54+
# p_again = Prompt(prompt, self.df, self.x, self.y, self.z, previous_code=response, previous_error=error)
5455

5556
# Logger().log({"title": "Prompt with fix", "details": p_again.value})
5657

requirements.txt

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
matplotlib
2-
pandas
3-
numpy
4-
openai
5-
python-dotenv
1+
matplotlib~=3.8.3
2+
pandas~=2.2.0
3+
numpy~=1.26.4
4+
openai~=1.12.0
5+
python-dotenv~=1.0.1

setup.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
setup(
1212
name="plotai",
13-
version="0.0.2",
13+
version="0.0.3",
1414
description="Create plots in Python with AI",
1515
long_description=long_description,
1616
long_description_content_type="text/markdown",
@@ -21,7 +21,7 @@
2121
packages=find_packages(exclude=["*.tests", "*.tests.*", "tests.*", "tests"]),
2222
install_requires=open("requirements.txt").readlines(),
2323
include_package_data=True,
24-
python_requires='>=3.7.1',
24+
python_requires=">=3.7.1",
2525
classifiers=[
2626
"Programming Language :: Python",
2727
"Programming Language :: Python :: 3.7",
@@ -37,6 +37,6 @@
3737
"matplotlib",
3838
"llm",
3939
"openai",
40-
"mljar"
40+
"mljar",
4141
],
4242
)

tests/test_plotai.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,9 @@ def test_pass_data(self):
1717
df2 = pd.DataFrame({"x":np.random.rand(rows), "y": np.random.rand(rows)})
1818
plot = PlotAI(df=df2)
1919
#plot.make("Plot a scatter plot")
20+
21+
def test_gpt4(self):
22+
rows = 100
23+
df2 = pd.DataFrame({"x":np.random.rand(rows), "y": np.random.rand(rows)})
24+
plot = PlotAI(df=df2, model_version="gpt4")
25+
#plot.make("Plot a scatter plot")

0 commit comments

Comments
 (0)