Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import subprocess
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Optional, Union, Tuple, TypeAlias
from typing import List, Optional, Union, Tuple, TypeAlias, Dict

import matplotlib.pyplot as plt
from matplotlib.axes import Axes
Expand Down Expand Up @@ -194,13 +194,91 @@ def clear_input_components(self) -> None:
for attr, default in refinement_defaults.items():
setattr(self, attr, default)

# def run_model(
# self,
# load_parameters: list,
# refinement_parameters: list,
# project_name: str,
# rb_num: Optional[str] = None,
# user_x_limits: Optional[List[List[float]]] = None,
# ) -> Optional[int]:
# self.clear_input_components()
# if not self.initial_validation(project_name, load_parameters):
# return None
# self.set_components_from_inputs(load_parameters, refinement_parameters, project_name, rb_num)
# self.read_phase_files()
# self.generate_reflections_from_space_group()

# formatted_limits: Optional[List[List[float]]] = None
# # Ensure both elements are lists of floats and pass formatted limits to validate_x_limits
# if isinstance(user_x_limits, list) and len(user_x_limits) == 2:
# formatted_limits = [
# user_x_limits[0] if isinstance(user_x_limits[0], list) else [user_x_limits[0]],
# user_x_limits[1] if isinstance(user_x_limits[1], list) else [user_x_limits[1]],
# ]

# self.validate_x_limits(formatted_limits)
# if not self.further_validation():
# return None

# runtime = self.call_gsas2()
# if not runtime:
# return None
# runtime_str = runtime[1] # Extract the string part of the runtime tuple
# report_result = self.report_on_outputs(runtime_str)
# if report_result is not None:
# gsas_result_filepath, _ = report_result # Unpack the tuple
# else:
# logger.error("Failed to unpack the result from report_on_outputs.")
# return None
# gsas_result = gsas_result_filepath
# if not gsas_result:
# return None
# self.load_basic_outputs(gsas_result)

# if self.state.number_of_regions > self.state.number_histograms:
# return self.state.number_of_regions
# return self.state.number_histograms

def run_model(
self,
load_parameters: list,
refinement_parameters: list,
project_name: str,
rb_num: Optional[str] = None,
user_x_limits: Optional[List[List[float]]] = None,
) -> Optional[Dict[str, int]]:
"""
Returns a dictionary mapping data file names to their result counts
"""
data_files = load_parameters[2] # Extract data files list
num_hist = None

for data_file in data_files:
# Create unique project name for each file
file_basename = os.path.splitext(os.path.basename(data_file))[0]
individual_project_name = f"{project_name}_{file_basename}"

# Create modified load_parameters for single file
single_file_load_params = [
load_parameters[0], # instrument_files (reuse same)
load_parameters[1], # phase_filepaths (reuse same)
[data_file], # single data file
]

num_hist = self._run_single_refinement(
single_file_load_params, refinement_parameters, individual_project_name, rb_num, user_x_limits
)

return num_hist

def _run_single_refinement(
self,
load_parameters: list,
refinement_parameters: list,
project_name: str,
rb_num: Optional[str] = None,
user_x_limits: Optional[List[List[float]]] = None,
) -> Optional[int]:
self.clear_input_components()
if not self.initial_validation(project_name, load_parameters):
Expand Down Expand Up @@ -240,6 +318,9 @@ def run_model(
return self.state.number_of_regions
return self.state.number_histograms

# C:/MantidInstall/scripts/Engineering/ENGINX/phase_info/
# C:/Users/joy22959/Engineering_Mantid/User/test/Focus/

# ===============
# Prepare Inputs
# ===============
Expand Down Expand Up @@ -912,9 +993,13 @@ def create_lattice_parameter_table(self, test: bool = False) -> Optional[Union[N
# ===========

def load_focused_nxs_for_logs(self, filenames: List[str]) -> None:
if len(filenames) == 1 and "all_banks" in filenames[0]:
filenames = [filenames[0].replace("all_banks", "bank_1"), filenames[0].replace("all_banks", "bank_2")]
banks_filenames = []
for filename in filenames:
if "all_banks" in filename:
banks_filenames.extend([filename.replace("all_banks", "bank_1"), filename.replace("all_banks", "bank_2")])
else:
banks_filenames.append(filename)
for filename in banks_filenames:
filename = filename.replace(".gss", ".nxs")
ws_name = _generate_workspace_name(filename, self._suffix)
if ws_name not in self._data_workspaces.get_loaded_workpace_names():
Expand Down