Skip to content

Commit 6417b1d

Browse files
committed
Merge branch 'ft-openturns-doe-driver'
2 parents 06bd5c3 + daa22d5 commit 6417b1d

File tree

5 files changed

+162
-22
lines changed

5 files changed

+162
-22
lines changed
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
"""
2+
Driver for running model on design of experiments cases using OpenTURNS sampling methods
3+
"""
4+
from __future__ import print_function
5+
import numpy as np
6+
from six import iteritems
7+
8+
from openmdao.api import DOEDriver, OptionsDictionary
9+
from openmdao.drivers.doe_generators import DOEGenerator
10+
11+
OPENTURNS_NOT_INSTALLED = False
12+
try:
13+
import openturns as ot
14+
except ImportError:
15+
OPENTURNS_NOT_INSTALLED = True
16+
17+
18+
class OpenturnsMonteCarloDOEGenerator(DOEGenerator):
19+
LIMIT = 1e12
20+
21+
def __init__(self, n_samples=10, dist=None):
22+
super(OpenturnsMonteCarloDOEGenerator, self).__init__()
23+
24+
self.n_samples = n_samples
25+
self.distribution = dist
26+
self.called = False
27+
28+
def __call__(self, uncertain_vars, model=None):
29+
if self.distribution is None:
30+
dists = []
31+
for name, meta in iteritems(uncertain_vars):
32+
size = meta["size"]
33+
meta_low = meta["lower"]
34+
meta_high = meta["upper"]
35+
for j in range(size):
36+
if isinstance(meta_low, np.ndarray):
37+
p_low = meta_low[j]
38+
else:
39+
p_low = meta_low
40+
p_low = max(p_low, -self.LIMIT)
41+
42+
if isinstance(meta_high, np.ndarray):
43+
p_high = meta_high[j]
44+
else:
45+
p_high = meta_high
46+
p_high = min(p_high, self.LIMIT)
47+
48+
dists.append(ot.Uniform(p_low, p_high))
49+
self.distribution = ot.ComposedDistribution(dists)
50+
else:
51+
size = 0
52+
for name, meta in iteritems(uncertain_vars):
53+
size += meta["size"]
54+
if (size) != (self.distribution.getDimension()):
55+
raise RuntimeError(
56+
"Bad distribution dimension: should be equal to uncertain variables size {} "
57+
", got {}".format(size, self.distribution.getDimension())
58+
)
59+
samples = self.distribution.getSample(self.n_samples)
60+
self._cases = np.array(samples)
61+
self.called = True
62+
sample = []
63+
for i in range(self._cases.shape[0]):
64+
j = 0
65+
for name, meta in iteritems(uncertain_vars):
66+
size = meta["size"]
67+
sample.append((name, self._cases[i, j : j + size]))
68+
j += size
69+
yield sample
70+
71+
def get_cases(self):
72+
if not self.called:
73+
raise RuntimeError("Have to run the driver before getting cases")
74+
return self._cases
75+
76+
77+
class OpenturnsDOEDriver(DOEDriver):
78+
"""
79+
Baseclass for OpenTURNS design-of-experiments Drivers
80+
"""
81+
82+
def __init__(self, **kwargs):
83+
super(OpenturnsDOEDriver, self).__init__()
84+
85+
self.options.declare(
86+
"distribution",
87+
types=ot.ComposedDistribution,
88+
default=None,
89+
allow_none=True,
90+
desc="Joint distribution of uncertain variables",
91+
)
92+
self.options.declare(
93+
"n_samples", types=int, default=2, desc="number of sample to generate"
94+
)
95+
self.options.update(kwargs)
96+
97+
self.options["generator"] = OpenturnsMonteCarloDOEGenerator(
98+
self.options["n_samples"], self.options["distribution"]
99+
)
100+
101+
def _set_name(self):
102+
self._name = "OpenTURNS_DOE_MonteCarlo"
103+
104+
def get_cases(self):
105+
return self.options["generator"].get_cases()

openmdao_extensions/smt_doe_driver.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -107,12 +107,3 @@ def __init__(self, **kwargs):
107107
def _set_name(self):
108108
self._name = "SMT_DOE_" + self.options["sampling_method_name"]
109109

110-
111-
class SmtDoeDriver(SmtDOEDriver):
112-
"""
113-
Deprecated. Use SmtDOEDriver.
114-
"""
115-
116-
def __init__(self, **kwargs):
117-
super(SmtDoeDriver, self).__init__(**kwargs)
118-
warn_deprecation("'SmtDoeDriver' is deprecated; " "use 'SmtDOEDriver' instead.")
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import os
2+
import unittest
3+
from six import itervalues
4+
from openmdao.api import IndepVarComp, Problem, SqliteRecorder, CaseReader, DOEDriver
5+
from openmdao.test_suite.components.sellar import SellarProblem
6+
from openmdao_extensions.openturns_doe_driver import OpenturnsDOEDriver
7+
from openmdao_extensions.openturns_doe_driver import OPENTURNS_NOT_INSTALLED
8+
import openturns as ot
9+
10+
11+
class TestOpenturnsDoeDriver(unittest.TestCase):
12+
@staticmethod
13+
def run_driver(name, driver):
14+
pb = SellarProblem()
15+
case_recorder_filename = "test_openturns_doe_{}.sqlite".format(name)
16+
recorder = SqliteRecorder(case_recorder_filename)
17+
pb.driver = driver
18+
pb.driver.add_recorder(recorder)
19+
pb.setup()
20+
pb.run_driver()
21+
pb.cleanup()
22+
return pb, case_recorder_filename
23+
24+
def test_openturns_doe_driver(self):
25+
ns = 100
26+
driver = OpenturnsDOEDriver(n_samples=ns)
27+
TestOpenturnsDoeDriver.run_driver("mc", driver)
28+
cases = driver.get_cases()
29+
self.assertEqual((100, 3), cases.shape)
30+
31+
def test_openturns_doe_driver_with_dist(self):
32+
ns = 100
33+
dists = [ot.Normal(2, 1), ot.Normal(5, 1), ot.Normal(2, 1)]
34+
driver = OpenturnsDOEDriver(
35+
n_samples=ns, distribution=ot.ComposedDistribution(dists)
36+
)
37+
TestOpenturnsDoeDriver.run_driver("mc", driver)
38+
cases = driver.get_cases()
39+
self.assertEqual((100, 3), cases.shape)
40+
41+
def test_bad_dist(self):
42+
ns = 100
43+
dists = [ot.Normal(2, 1), ot.Normal(5, 1)]
44+
driver = OpenturnsDOEDriver(
45+
n_samples=ns, distribution=ot.ComposedDistribution(dists)
46+
)
47+
48+
with self.assertRaises(RuntimeError):
49+
TestOpenturnsDoeDriver.run_driver("mc", driver)
50+
51+
52+
if __name__ == "__main__":
53+
unittest.main()

openmdao_extensions/tests/test_smt_doe_driver.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import unittest
33
from openmdao.api import SqliteRecorder, CaseReader
44
from openmdao.test_suite.components.sellar import SellarProblem
5-
from openmdao_extensions.smt_doe_driver import SmtDOEDriver, SmtDoeDriver
5+
from openmdao_extensions.smt_doe_driver import SmtDOEDriver
66
from openmdao_extensions.smt_doe_driver import SMT_NOT_INSTALLED
77
from openmdao.utils.assert_utils import assert_warning
88

@@ -52,12 +52,5 @@ def test_smt_rand_doe_driver(self):
5252
n, SmtDOEDriver(sampling_method_name="Random", n_cases=n)
5353
)
5454

55-
@unittest.skipIf(SMT_NOT_INSTALLED, "SMT library is not installed")
56-
def test_deprecated(self):
57-
msg = "'SmtDoeDriver' is deprecated; " "use 'SmtDOEDriver' instead."
58-
with assert_warning(DeprecationWarning, msg):
59-
SmtDoeDriver()
60-
61-
6255
if __name__ == "__main__":
6356
unittest.main()

setup.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,7 @@
1919
Intended Audience :: Developers
2020
License :: OSI Approved :: Apache Software License
2121
Programming Language :: Python
22-
Programming Language :: Python :: 2.7
23-
Programming Language :: Python :: 3.6
24-
Programming Language :: Python :: 3.7
22+
Programming Language :: Python :: 3
2523
Topic :: Software Development
2624
Topic :: Scientific/Engineering
2725
Operating System :: Microsoft :: Windows
@@ -31,7 +29,7 @@
3129

3230
metadata = dict(
3331
name="openmdao_extensions",
34-
version="0.3.2",
32+
version="0.4.0",
3533
description="Additional solvers and drivers for OpenMDAO framework",
3634
long_description=long_description,
3735
long_description_content_type="text/markdown",
@@ -41,7 +39,7 @@
4139
classifiers=[_f for _f in CLASSIFIERS.split("\n") if _f],
4240
packages=["openmdao_extensions"],
4341
install_requires=["openmdao"],
44-
python_requires=">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*",
42+
python_requires="!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*",
4543
zip_safe=True,
4644
url="https://github.com/OneraHub/opendmao_extensions",
4745
download_url="https://github.com/OneraHub/openmdao_extensions/releases",

0 commit comments

Comments
 (0)