Skip to content

Commit b485dd5

Browse files
Change signature of solve in nextmv.Model
1 parent 3d1961d commit b485dd5

File tree

6 files changed

+27
-35
lines changed

6 files changed

+27
-35
lines changed

README.md

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -238,10 +238,10 @@ Write the output data after a run is completed.
238238
#### Model
239239
240240
A decision model is a program that makes decisions, i.e.: solves decision
241-
problems. The model takes in an input (representing the problem data), options
242-
to configure the program, and returns an output, which is the solution to the
243-
decision problem. The `nextmv.Model` class is the base class for all models. It
244-
holds the necessary logic to handle all decisions.
241+
problems. The model takes in an input (representing the problem data and
242+
options) and returns an output, which is the solution to the decision problem.
243+
The `nextmv.Model` class is the base class for all models. It holds the
244+
necessary logic to handle all decisions.
245245
246246
When creating your own decision model, you must create a class that inherits
247247
from `nextmv.Model` and implement the `solve` method.
@@ -250,7 +250,7 @@ from `nextmv.Model` and implement the `solve` method.
250250
import nextmv
251251
252252
class YourCustomModel(nextmv.Model):
253-
def solve(self, input: nextmv.Input, options: nextmv.Options) -> nextmv.Output:
253+
def solve(self, input: nextmv.Input) -> nextmv.Output:
254254
"""Implement the logic to solve the decision problem here."""
255255
pass
256256
```
@@ -296,15 +296,15 @@ import highspy
296296
import nextmv
297297

298298
class DecisionModel(nextmv.Model):
299-
def solve(self, input: nextmv.Input, options: nextmv.Options) -> nextmv.Output:
299+
def solve(self, input: nextmv.Input) -> nextmv.Output:
300300
"""Solves the given problem and returns the solution."""
301301

302302
start_time = time.time()
303303

304304
# Creates the solver.
305305
solver = highspy.Highs()
306306
solver.silent() # Solver output ignores stdout redirect, silence it.
307-
solver.setOptionValue("time_limit", options.duration)
307+
solver.setOptionValue("time_limit", input.options.duration)
308308

309309
# Initializes the linear sums.
310310
weights = 0.0
@@ -330,7 +330,7 @@ class DecisionModel(nextmv.Model):
330330
item["item"] for item in items if solver.val(item["variable"]) > 0.9
331331
]
332332

333-
options.version = version("highspy")
333+
input.options.version = version("highspy")
334334

335335
statistics = nextmv.Statistics(
336336
run=nextmv.RunStatistics(duration=time.time() - start_time),
@@ -345,7 +345,7 @@ class DecisionModel(nextmv.Model):
345345
)
346346

347347
return nextmv.Output(
348-
options=options,
348+
options=input.options,
349349
solution={"items": chosen_items},
350350
statistics=statistics,
351351
)
@@ -361,7 +361,7 @@ import nextmv
361361

362362
model = DecisionModel()
363363
input = nextmv.Input(data=sample_input, options=options)
364-
output = model.solve(input, options)
364+
output = model.solve(input)
365365
print(json.dumps(output.solution, indent=2))
366366
```
367367

@@ -446,7 +446,7 @@ There are two strategies to push an application to the Nextmv Cloud:
446446
from nextmv import cloud
447447

448448
class CustomDecisionModel(nextmv.Model):
449-
def solve(self, input: nextmv.Input, options: nextmv.Options) -> nextmv.Output:
449+
def solve(self, input: nextmv.Input) -> nextmv.Output:
450450
"""Implement the logic to solve the decision problem here."""
451451
pass
452452

nextmv/__entrypoint__.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111

1212
from nextmv.cloud.manifest import Manifest
1313
from nextmv.input import load_local
14-
from nextmv.model import ModelConfiguration
1514
from nextmv.options import Options
1615
from nextmv.output import write_local
1716

@@ -28,18 +27,14 @@ def main() -> None:
2827
options = Options.from_parameters_dict(parameters_dict)
2928

3029
# Load the model.
31-
model_configuration = ModelConfiguration(
32-
name=manifest.python.model.name,
33-
options=options,
34-
)
3530
loaded_model = load_model(
36-
model_uri=model_configuration.name,
31+
model_uri=manifest.python.model.name,
3732
suppress_warnings=True,
3833
)
3934

4035
# Load the input and solve the model by using mlflow’s inference API.
4136
input = load_local(options=options)
42-
output = loaded_model.predict(input, params=options.to_dict())
37+
output = loaded_model.predict(input)
4338

4439
# Write the output.
4540
write_local(output)

nextmv/cloud/application.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -712,13 +712,13 @@ def push(
712712
# Define the model that makes decisions. This model uses the Nextroute
713713
# library to solve a vehicle routing problem.
714714
class DecisionModel(nextmv.Model):
715-
def solve(self, input: nextmv.Input, options: nextmv.Options) -> nextmv.Output:
715+
def solve(self, input: nextmv.Input) -> nextmv.Output:
716716
nextroute_input = nextroute.schema.Input.from_dict(input.data)
717-
nextroute_options = nextroute.Options.extract_from_dict(options.to_dict())
717+
nextroute_options = nextroute.Options.extract_from_dict(input.options.to_dict())
718718
nextroute_output = nextroute.solve(nextroute_input, nextroute_options)
719719
720720
return nextmv.Output(
721-
options=options,
721+
options=input.options,
722722
solution=nextroute_output.solutions[0].to_dict(),
723723
statistics=nextroute_output.statistics.to_dict(),
724724
)

nextmv/model.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -70,32 +70,30 @@ class Model:
7070
# Define the model that makes decisions. This model uses the Nextroute library
7171
# to solve a routing problem.
7272
class DecisionModel(nextmv.Model):
73-
def solve(self, input: nextmv.Input, options: nextmv.Options) -> nextmv.Output:
73+
def solve(self, input: nextmv.Input) -> nextmv.Output:
7474
nextroute_input = nextroute.schema.Input.from_dict(input.data)
75-
nextroute_options = nextroute.Options.extract_from_dict(options.to_dict())
75+
nextroute_options = nextroute.Options.extract_from_dict(input.options.to_dict())
7676
nextroute_output = nextroute.solve(nextroute_input, nextroute_options)
7777
7878
return nextmv.Output(
79-
options=options,
79+
options=input.options,
8080
solution=nextroute_output.solutions[0].to_dict(),
8181
statistics=nextroute_output.statistics.to_dict(),
8282
)
8383
```
8484
"""
8585

86-
def solve(self, input: Input, options: Options) -> Output:
86+
def solve(self, input: Input) -> Output:
8787
"""
8888
The `solve` method is the main entry point of your model. You must
89-
implement this method yourself. It receives a `nextmv.Input` and
90-
`nextmv.Options` and should process them to produce a `nextmv.Output`,
91-
which is the solution to the decision model/problem.
89+
implement this method yourself. It receives a `nextmv.Input` and should
90+
process it to produce a `nextmv.Output`, which is the solution to the
91+
decision model/problem.
9292
9393
Parameters
9494
----------
9595
input : Input
9696
The input data that the model will use to make a decision.
97-
options : Options
98-
The options that the model will use to make a decision.
9997
10098
Returns
10199
-------
@@ -142,7 +140,7 @@ class MLFlowModel(PythonModel):
142140
Nextmv `DecisionModel` into an `mlflow.pyfunc.PythonModel`. This
143141
class must comply with the inference API of mlflow, which is why it
144142
has a `predict` method. The translation happens by having this
145-
`predict` methos call the user-defined `solve` method of the
143+
`predict` method call the user-defined `solve` method of the
146144
`DecisionModel`.
147145
"""
148146

@@ -155,8 +153,7 @@ def predict(mlflow_self, context, model_input, params=None) -> Any:
155153
[python_function]: https://mlflow.org/docs/latest/python_api/mlflow.pyfunc.html
156154
"""
157155

158-
options = Options.from_dict(params)
159-
return self.solve(model_input, options)
156+
return self.solve(model_input)
160157

161158
# Some annoying logging from mlflow must be disabled.
162159
logging.disable(logging.CRITICAL)

tests/test_entrypoint/test_entrypoint.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111

1212
class SimpleDecisionModel(nextmv.Model):
13-
def solve(self, input: nextmv.Input, options: nextmv.Options) -> nextmv.Output:
13+
def solve(self, input: nextmv.Input) -> nextmv.Output:
1414
return nextmv.Output(
1515
solution={"foo": "bar"},
1616
statistics={"baz": "qux"},

tests/test_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
class ModelForTesting(nextmv.Model):
99
"""Dummy decision model for testing purposes."""
1010

11-
def solve(self, input: nextmv.Input, options: nextmv.Options) -> nextmv.Output:
11+
def solve(self, input: nextmv.Input) -> nextmv.Output:
1212
return nextmv.Output()
1313

1414

0 commit comments

Comments
 (0)