Skip to content

Commit 9d914ae

Browse files
authored
GEN-1292 new protos (#9)
* max_rows in gravity
* Squashed 'protos/' changes from 8c94d7c..69137fc
  - 69137fc Merge pull request #13 from macrocosm-os/gen-1222_add_missing_nebula_to_dataset_query
  - c024b86 GEN-1222 Add missing nebula to dataset query
  - ee2edaa Merge pull request #12 from macrocosm-os/optional-fields-2
  - ada0239 max_rows is not optional - revert this change
  - d452c9a make some fields in the proto files optional
  git-subtree-dir: protos
  git-subtree-split: 69137fcf99e56523bb9a6f74e267caf080056e4f
* gen protos
* improve working example
* bump version
* add type
1 parent c619ec8 commit 9d914ae

File tree

17 files changed

+215
-153
lines changed

17 files changed

+215
-153
lines changed

examples/gravity_workflow_example.py

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,18 @@
2020

2121
class GravityWorkflow:
2222
def __init__(
23-
self, task_name: str, email: str, reddit_subreddit: str, x_hashtag: str
23+
self,
24+
task_name: str,
25+
email: str,
26+
reddit_subreddit: str,
27+
x_hashtag: str,
28+
max_rows: int,
2429
):
2530
self.task_name = task_name
2631
self.email = email
2732
self.reddit_subreddit = reddit_subreddit
2833
self.x_hashtag = x_hashtag
34+
self.max_rows = max_rows
2935
self.api_key = os.environ.get(
3036
"GRAVITY_API_KEY", os.environ.get("MACROCOSMOS_API_KEY")
3137
)
@@ -261,7 +267,9 @@ async def build_datasets(self, crawler_ids: Set[str]):
261267

262268
# Build dataset
263269
response = await self.client.gravity.BuildDataset(
264-
crawler_id=crawler_id, notification_requests=[notification]
270+
crawler_id=crawler_id,
271+
max_rows=self.max_rows,
272+
notification_requests=[notification],
265273
)
266274

267275
if response and response.dataset_id:
@@ -477,7 +485,22 @@ def get_user_input():
477485
if not task_name:
478486
task_name = "MyTestTask"
479487

480-
return email, reddit, x_hashtag, task_name
488+
# Get max rows with default
489+
while True:
490+
max_rows_input = input("Max rows per dataset [1000]: ").strip()
491+
if not max_rows_input:
492+
max_rows = 1000
493+
break
494+
try:
495+
max_rows = int(max_rows_input)
496+
if max_rows <= 0:
497+
print("Please enter a positive number.")
498+
continue
499+
break
500+
except ValueError:
501+
print("Please enter a valid number.")
502+
503+
return email, reddit, x_hashtag, task_name, max_rows
481504

482505

483506
async def main():
@@ -486,12 +509,12 @@ async def main():
486509
print("════════════════════════════════════════════════════")
487510

488511
# Get user input with defaults
489-
email, reddit, x_hashtag, task_name = get_user_input()
512+
email, reddit, x_hashtag, task_name, max_rows = get_user_input()
490513

491514
# Set up signal handlers for graceful shutdown
492515
loop = asyncio.get_running_loop()
493516

494-
workflow = GravityWorkflow(task_name, email, reddit, x_hashtag)
517+
workflow = GravityWorkflow(task_name, email, reddit, x_hashtag, max_rows)
495518

496519
# Register signal handlers
497520
for sig in (signal.SIGINT, signal.SIGTERM):

protos/README.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
# Macrocosmos Protobufs
2+
23
Protobufs used by Macrocosmos products
34

4-
# Usage
5+
## Usage
6+
57
Here are instructions on how to add and update the protos from your own project repo.
68

79
> ⚠️ These commands must be run from the target repo's root directory

protos/apex/v1/apex.proto

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -23,25 +23,25 @@ message ChatCompletionRequest {
2323
// messages: the messages to generate completions for.
2424
repeated ChatMessage messages = 2;
2525
// seed: the seed to use for the completion.
26-
int64 seed = 3;
26+
optional int64 seed = 3;
2727
// task: the task to generate completions for (e.g. "InferenceTask").
28-
string task = 4;
28+
optional string task = 4;
2929
// model: the LLM name to use for the completion. (optional, suggest leaving this empty as not all LLMs are supported)
30-
string model = 5;
30+
optional string model = 5;
3131
// test_time_inference: whether to use test time inference.
32-
bool test_time_inference = 6;
32+
optional bool test_time_inference = 6;
3333
// mixture: whether to use a mixture of miners to create a slower but better answer.
34-
bool mixture = 7;
34+
optional bool mixture = 7;
3535
// sampling_parameters: the sampling parameters to use for the completion.
36-
SamplingParameters sampling_parameters = 8;
36+
optional SamplingParameters sampling_parameters = 8;
3737
// inference_mode: the inference mode to use for the completion.
38-
string inference_mode = 9;
38+
optional string inference_mode = 9;
3939
// json_format: whether to use JSON format for the completion.
40-
bool json_format = 10;
40+
optional bool json_format = 10;
4141
// stream: whether to stream the completion.
42-
bool stream = 11;
42+
optional bool stream = 11;
4343
// timeout: the timeout for the completion in seconds.
44-
int64 timeout = 12;
44+
optional int64 timeout = 12;
4545
}
4646

4747
// The sampling parameters for the completion.
@@ -52,7 +52,7 @@ message SamplingParameters {
5252
// top_p: the top_p to use for the completion.
5353
double top_p = 2;
5454
// top_k: the top_k to use for the completion.
55-
double top_k = 3;
55+
optional double top_k = 3;
5656
// max_new_tokens: the max_new_tokens to use for the completion.
5757
int64 max_new_tokens = 4;
5858
// do_sample: whether to do sample for the completion.
@@ -330,13 +330,13 @@ message WebRetrievalRequest {
330330
// search_query: the search query.
331331
string search_query = 2;
332332
// n_miners: the number of miners to use for the query.
333-
int64 n_miners = 3;
333+
optional int64 n_miners = 3;
334334
// n_results: the number of results to return.
335-
int64 n_results = 4;
335+
optional int64 n_results = 4;
336336
// max_response_time: the max response time to allow for the miners to respond in seconds.
337-
int64 max_response_time = 5;
337+
optional int64 max_response_time = 5;
338338
// timeout: the timeout for the web retrieval in seconds.
339-
int64 timeout = 6;
339+
optional int64 timeout = 6;
340340
}
341341

342342
// A web search result from Apex

protos/billing/v1/billing.proto

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ service BillingService {
1212
// GetUsageRequest is the request message for getting the usage of the user's credits
1313
message GetUsageRequest {
1414
// product_type: the type of the product (i.e. "gravity")
15-
string product_type = 1;
15+
optional string product_type = 1;
1616
}
1717

1818
// ProductPlan is details of the subscription plan for a product

protos/gravity/v1/gravity.proto

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -115,9 +115,9 @@ message GravityTaskState {
115115
// GetGravityTasksRequest is the request message for listing gravity tasks for a user
116116
message GetGravityTasksRequest {
117117
// gravity_task_id: the ID of the gravity task (optional, if not provided, all gravity tasks for the user will be returned)
118-
string gravity_task_id = 1;
118+
optional string gravity_task_id = 1;
119119
// include_crawlers: whether to include the crawler states in the response
120-
bool include_crawlers = 2;
120+
optional bool include_crawlers = 2;
121121
}
122122

123123
// GetGravityTasksResponse is the response message for listing gravity tasks for a user
@@ -141,7 +141,7 @@ message NotificationRequest {
141141
// address: the address to send the notification to (only email addresses are supported currently)
142142
string address = 2;
143143
// redirect_url: the URL to include in the notication message that redirects the user to any built datasets
144-
string redirect_url = 3;
144+
optional string redirect_url = 3;
145145
}
146146

147147
// GetCrawlerRequest is the request message for getting a crawler
@@ -166,7 +166,7 @@ message CreateGravityTaskRequest {
166166
// that is automatically generated upon completion of the crawler is ready to download (optional)
167167
repeated NotificationRequest notification_requests = 3;
168168
// gravity_task_id: the ID of the gravity task (optional, default will generate a random ID)
169-
string gravity_task_id = 4;
169+
optional string gravity_task_id = 4;
170170
}
171171

172172
// CreateGravityTaskResponse is the response message for creating a new gravity task
@@ -194,6 +194,15 @@ message BuildDatasetResponse {
194194
Dataset dataset = 2;
195195
}
196196

197+
message Nebula {
198+
// error: nebula build error message
199+
string error = 1;
200+
// file_size_bytes: the size of the file in bytes
201+
int64 file_size_bytes = 2;
202+
// url: the URL of the file
203+
string url = 3;
204+
}
205+
197206
// Dataset contains the progress and results of a dataset build
198207
message Dataset {
199208
// crawler_workflow_id: the ID of the parent crawler for this dataset
@@ -212,6 +221,8 @@ message Dataset {
212221
repeated DatasetStep steps = 7;
213222
// total_steps: the total number of steps in the dataset build
214223
int64 total_steps = 8;
224+
// nebula: the details about the nebula that was built
225+
Nebula nebula = 9;
215226
}
216227

217228
// DatasetFile contains the details about a dataset file

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "macrocosmos"
3-
version = "1.0.0"
3+
version = "1.0.1"
44
description = "The official Python SDK for Macrocosmos"
55
readme = "README.md"
66
license = "Apache-2.0"

src/macrocosmos/generated/apex/v1/apex_p2p.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ class SamplingParameters(BaseModel):
3030
# top_p: the top_p to use for the completion.
3131
top_p: float = Field(default=0.0)
3232
# top_k: the top_k to use for the completion.
33-
top_k: float = Field(default=0.0)
33+
top_k: typing.Optional[float] = Field(default=0.0)
3434
# max_new_tokens: the max_new_tokens to use for the completion.
3535
max_new_tokens: int = Field(default=0)
3636
# do_sample: whether to do sample for the completion.
@@ -47,25 +47,25 @@ class ChatCompletionRequest(BaseModel):
4747
# messages: the messages to generate completions for.
4848
messages: typing.List[ChatMessage] = Field(default_factory=list)
4949
# seed: the seed to use for the completion.
50-
seed: int = Field(default=0)
50+
seed: typing.Optional[int] = Field(default=0)
5151
# task: the task to generate completions for (e.g. "InferenceTask").
52-
task: str = Field(default="")
52+
task: typing.Optional[str] = Field(default="")
5353
# model: the LLM name to use for the completion. (optional, suggest leaving this empty as not all LLMs are supported)
54-
model: str = Field(default="")
54+
model: typing.Optional[str] = Field(default="")
5555
# test_time_inference: whether to use test time inference.
56-
test_time_inference: bool = Field(default=False)
56+
test_time_inference: typing.Optional[bool] = Field(default=False)
5757
# mixture: whether to use a mixture of miners to create a slower but better answer.
58-
mixture: bool = Field(default=False)
58+
mixture: typing.Optional[bool] = Field(default=False)
5959
# sampling_parameters: the sampling parameters to use for the completion.
60-
sampling_parameters: SamplingParameters = Field(default_factory=SamplingParameters)
60+
sampling_parameters: typing.Optional[SamplingParameters] = Field(default_factory=SamplingParameters)
6161
# inference_mode: the inference mode to use for the completion.
62-
inference_mode: str = Field(default="")
62+
inference_mode: typing.Optional[str] = Field(default="")
6363
# json_format: whether to use JSON format for the completion.
64-
json_format: bool = Field(default=False)
64+
json_format: typing.Optional[bool] = Field(default=False)
6565
# stream: whether to stream the completion.
66-
stream: bool = Field(default=False)
66+
stream: typing.Optional[bool] = Field(default=False)
6767
# timeout: the timeout for the completion in seconds.
68-
timeout: int = Field(default=0)
68+
timeout: typing.Optional[int] = Field(default=0)
6969

7070
class TopLogprob(BaseModel):
7171
"""
@@ -372,13 +372,13 @@ class WebRetrievalRequest(BaseModel):
372372
# search_query: the search query.
373373
search_query: str = Field(default="")
374374
# n_miners: the number of miners to use for the query.
375-
n_miners: int = Field(default=0)
375+
n_miners: typing.Optional[int] = Field(default=0)
376376
# n_results: the number of results to return.
377-
n_results: int = Field(default=0)
377+
n_results: typing.Optional[int] = Field(default=0)
378378
# max_response_time: the max response time to allow for the miners to respond in seconds.
379-
max_response_time: int = Field(default=0)
379+
max_response_time: typing.Optional[int] = Field(default=0)
380380
# timeout: the timeout for the web retrieval in seconds.
381-
timeout: int = Field(default=0)
381+
timeout: typing.Optional[int] = Field(default=0)
382382

383383
class WebSearchResult(BaseModel):
384384
"""

0 commit comments

Comments
 (0)