Skip to content

Commit d985738

Browse files
committed
Simplify get_cross_section from an established AtomInfo
1 parent 6ea7217 commit d985738

File tree

4 files changed

+14
-37
lines changed

4 files changed

+14
-37
lines changed

Framework/PythonInterface/plugins/algorithms/Abins.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -230,13 +230,7 @@ def _fill_s_1d_workspace(self, *, s_points: np.ndarray, workspace: str, species:
230230
:param species: atom/isotope identity and data
231231
"""
232232
if species is not None:
233-
s_points = (
234-
s_points
235-
* self._scale
236-
* self.get_cross_section(
237-
scattering=self._scale_by_cross_section, protons_number=species.z_number, nucleons_number=species.nucleons_number
238-
)
239-
)
233+
s_points = s_points * self._scale * self.get_cross_section(scattering=self._scale_by_cross_section, species=species)
240234
dim = 1
241235
length = s_points.size
242236

scripts/abins/abins2.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -217,9 +217,7 @@ def _fill_s_1d_workspace(self, s_points=None, workspace=None, species: AtomInfo
217217
:param workspace: workspace to be filled with S
218218
"""
219219
if species is not None:
220-
s_points = s_points * self.get_cross_section(
221-
scattering=self._scale_by_cross_section, protons_number=species.z_number, nucleons_number=species.nucleons_number
222-
)
220+
s_points = s_points * self.get_cross_section(scattering=self._scale_by_cross_section, species=species)
223221
dim = 1
224222
length = s_points.size
225223

scripts/abins/abinsalgorithm.py

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import os
1414
from pathlib import Path
1515
import re
16-
from typing import Dict, Iterable, List, Optional, Tuple, Union
16+
from typing import Dict, Iterable, List, Literal, Tuple, Union
1717

1818
import yaml
1919

@@ -664,29 +664,18 @@ def write_workspaces_to_ascii(scale: float = 1.0, *, ws_name: str) -> None:
664664
)
665665

666666
@staticmethod
667-
def get_cross_section(scattering: str = "Total", nucleons_number: Optional[int] = None, *, protons_number: int) -> float:
667+
def get_cross_section(scattering: Literal["Total", "Incoherent", "Coherent"], species: AtomInfo) -> float:
668668
"""
669669
Calculates cross section for the given element.
670670
:param scattering: Type of cross-section: 'Incoherent', 'Coherent' or 'Total'
671-
:param protons_number: number of protons in the given type fo atom
672-
:param nucleons_number: number of nucleons in the given type of atom
671+
:param species: Data for atom/isotope type
673672
:returns: cross section for that element
674673
"""
675-
if nucleons_number is not None:
676-
try:
677-
atom = Atom(a_number=nucleons_number, z_number=protons_number)
678-
# isotopes are not implemented for all elements so use different constructor in that cases
679-
except RuntimeError:
680-
logger.warning(f"Could not find data for isotope {nucleons_number}, " f"using default values for {protons_number} protons.")
681-
atom = Atom(z_number=protons_number)
682-
else:
683-
atom = Atom(z_number=protons_number)
684-
685674
scattering_keys = {"Incoherent": "inc_scatt_xs", "Coherent": "coh_scatt_xs", "Total": "tot_scatt_xs"}
686-
cross_section = atom.neutron()[scattering_keys[scattering]]
675+
cross_section = species.neutron_data[scattering_keys[scattering]]
687676

688677
if isnan(cross_section):
689-
raise ValueError(f"Found NaN cross-section for {atom.symbol} with {nucleons_number} nucleons.")
678+
raise ValueError(f"Found NaN cross-section for {species.symbol} with {species.nucleons_number} nucleons.")
690679

691680
return cross_section
692681

scripts/test/Abins/AbinsAlgorithmTest.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
# Import mantid.simpleapi first, otherwise we get circular import
1111
import mantid.simpleapi # noqa: F401
1212

13-
from abins.abinsalgorithm import AbinsAlgorithm
13+
from abins.abinsalgorithm import AbinsAlgorithm, AtomInfo
1414

1515

1616
class AtomsDataTest(unittest.TestCase):
@@ -19,16 +19,12 @@ class AtomsDataTest(unittest.TestCase):
1919
def test_cross_section(self):
2020
"""Get cross section from nucleus information"""
2121

22-
for scattering, nucleons_number, protons_number, expected in [
23-
("Incoherent", 67, 30, 0.28),
24-
("Coherent", None, 30, 4.054),
25-
("Total", None, 1, 82.02),
22+
for scattering, nucleons_number, symbol, expected in [
23+
("Incoherent", 67, "Zn", 0.28),
24+
("Coherent", 0, "Zn", 4.054),
25+
("Total", 0, "H", 82.02),
2626
]:
27-
xc = AbinsAlgorithm.get_cross_section(
28-
scattering=scattering,
29-
nucleons_number=nucleons_number,
30-
protons_number=protons_number,
31-
)
27+
xc = AbinsAlgorithm.get_cross_section(scattering=scattering, species=AtomInfo(mass=float(nucleons_number), symbol=symbol))
3228

3329
self.assertAlmostEqual(xc, expected)
3430

@@ -37,4 +33,4 @@ def test_get_bad_cross_section(self):
3733

3834
with self.assertRaisesRegex(ValueError, "Found NaN cross-section for Zn with 65 nucleons"):
3935
# Zn65 is unstable and has no recorded cross section values
40-
AbinsAlgorithm.get_cross_section(nucleons_number=65, protons_number=30)
36+
AbinsAlgorithm.get_cross_section("Total", AtomInfo(symbol="Zn", mass=65.0))

0 commit comments

Comments
 (0)