diff --git a/.LightshowAI/bin/python b/.LightshowAI/bin/python new file mode 120000 index 0000000..b8a0adb --- /dev/null +++ b/.LightshowAI/bin/python @@ -0,0 +1 @@ +python3 \ No newline at end of file diff --git a/.LightshowAI/bin/python3 b/.LightshowAI/bin/python3 new file mode 120000 index 0000000..ae65fda --- /dev/null +++ b/.LightshowAI/bin/python3 @@ -0,0 +1 @@ +/usr/bin/python3 \ No newline at end of file diff --git a/.LightshowAI/bin/python3.10 b/.LightshowAI/bin/python3.10 new file mode 120000 index 0000000..b8a0adb --- /dev/null +++ b/.LightshowAI/bin/python3.10 @@ -0,0 +1 @@ +python3 \ No newline at end of file diff --git a/.LightshowAI/lib64 b/.LightshowAI/lib64 new file mode 120000 index 0000000..7951405 --- /dev/null +++ b/.LightshowAI/lib64 @@ -0,0 +1 @@ +lib \ No newline at end of file diff --git a/.LightshowAI/pyvenv.cfg b/.LightshowAI/pyvenv.cfg new file mode 100644 index 0000000..0537ffc --- /dev/null +++ b/.LightshowAI/pyvenv.cfg @@ -0,0 +1,3 @@ +home = /usr/bin +include-system-site-packages = false +version = 3.10.12 diff --git a/html/index.html b/html/index.html index f2202e8..45b1a3b 100644 --- a/html/index.html +++ b/html/index.html @@ -99,7 +99,7 @@
- +

Lightshowai is a web user interface to predict the x-ray absorption near edge structure (XANES) spectra using neural network models. diff --git a/lightshowai/models.py b/lightshowai/models.py index 7cd0916..18d59e9 100644 --- a/lightshowai/models.py +++ b/lightshowai/models.py @@ -172,4 +172,4 @@ def predict(structure, absorbing_site, spectroscopy_type): element=absorbing_site, spectroscopy_type=spectroscopy_type ).predict(structure) result = {ii: spec[ii] for ii in site_idxs} - return result + return result \ No newline at end of file diff --git a/lightshowai/xas_ui.py b/lightshowai/xas_ui.py index 94b3f50..c7bd319 100644 --- a/lightshowai/xas_ui.py +++ b/lightshowai/xas_ui.py @@ -1,15 +1,41 @@ -from base64 import b64encode +import numpy as np + +def _patch_pymatgen_neighbors(): + try: + from pymatgen.optimization import neighbors as pmg_neighbors + _original_find_points = pmg_neighbors.find_points_in_spheres + + def _patched_find_points_in_spheres( + all_coords, center_coords, r, pbc, lattice, tol=1e-8 + ): + pbc = np.asarray(pbc, dtype=np.int64) + return _original_find_points( + all_coords, center_coords, r, pbc, lattice, tol + ) + + pmg_neighbors.find_points_in_spheres = _patched_find_points_in_spheres + print("Applied Windows int64 compatibility patch for pymatgen") + except Exception as e: + print(f"Warning: Could not apply pymatgen patch: {e}") + +_patch_pymatgen_neighbors() + + +from base64 import b64encode, b64decode import os +import io import tempfile import pathlib from zipfile import ZipFile +import re import pandas as pd import numpy as np - import dash from dash import dcc, html import plotly.express as px -from dash.dependencies import Input, Output, State +import plotly.graph_objects as go + +from dash.dependencies import Input, Output, State, ALL from dash.exceptions import PreventUpdate from pymatgen.core.structure import Structure from mp_api.client import MPRester @@ -23,6 +49,7 @@ ) from lightshowai.models import predict +from lightshowai.postprocess import compare_utils app = dash.Dash(prevent_initial_callbacks=True, title="OmniXAS@Lightshow.ai", @@ -34,6 +61,38 @@ show_export_button=False) search_component = ctc.SearchComponent(id='mpid_search') upload_component = ctc.StructureMoleculeUploadComponent(id='file_loader') + +# Combined single/multiple structure upload component +batch_upload_component = dcc.Upload( + id='batch_structure_upload', + children=html.Div([ + html.Div([ + 'Drag & Drop or ', + html.A('Select File(s)', style={'color': '#333', 'cursor': 'pointer', 'fontWeight': '500', 'textDecoration': 'underline'}) + ]) + ]), + style={ + 'width': '100%', + 'height': '50px', + 'lineHeight': '50px', + 'borderWidth': '1px', + 'borderStyle': 'dashed', + 'borderColor': '#d0d0d0', + 'borderRadius': '6px', + 'textAlign': 'center', + 'backgroundColor': '#fafafa', + 'cursor': 'pointer', + 'color': '#666', + 'fontSize': '12px', + 'fontFamily': "system-ui, -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif" + }, + multiple=True, # Allow single or multiple file selection + accept='.cif,.vasp,.poscar,.json' +) + +# Store for batch processing status +batch_processing_store = dcc.Store(id='batch_processing_store', data={'status': 'idle', 'processed': 0, 'total': 0}) + xas_plot = dcc.Graph(id='xas_plot') st_source = html.H1(id='st_source', children='No structure loaded yet') @@ -44,23 +103,894 @@ xas_model_names = [f'{el} FEFF' for el in all_elements] + ['Ti VASP', 'Cu VASP'] absorber_dropdown = dcc.Dropdown(xas_model_names, clearable=False, value='Ti VASP', id='absorber') -onmixas_layout = Columns([ - Column(Box([Loading(search_component.layout()), - Loading(upload_component.layout()), - html.Br(), html.Br(), - st_source, - html.Br(), html.Br(), - html.Div("Element and Theory:"), - Loading(absorber_dropdown)], - style={"width": "350px"}), narrow=True), - Column(Loading(struct_component.layout(size="70%"))), - Column(Box([xas_plot, - html.Button("Download POSCAR and Spectrum", id="download_btn"), - dcc.Download(id="download_sink")])) +# All available metrics for display +ALL_METRICS = ["coss_deriv", "pearson", "spearman", "coss", "kendalltaub", "normed_wasserstein"] + +# Short display names for table headers +METRIC_SHORT_NAMES = { + "coss_deriv": "Cos(∂)", + "pearson": "Pearson", + "spearman": "Spearman", + "coss": "Cosine", + "kendalltaub": "Kendall", + "normed_wasserstein": "Wasser.", +} + + +def get_spectrum_match_score(predicted_spectrum, exp_spectrum, element): + """ + Compare predicted spectrum against experimental spectrum using + lightshow.postprocess.compare_utils.compare_between_spectra. + + Returns comparison_range which is the energy range used for comparison. + """ + try: + ene = ene_grid[element] + ml_spectrum = np.column_stack((ene, predicted_spectrum)) + exp_energy = np.array(exp_spectrum['energy']) + exp_absorption = np.array(exp_spectrum['absorption']) + expt_spectrum = np.column_stack((exp_energy, exp_absorption)) + + opt_metric = "coss_deriv" + other_metrics = ["pearson", "spearman", "coss", "kendalltaub", "coss_deriv", "normed_wasserstein"] + + erange = 35 + erange_threshold = 0.04 + truncation_strategy = "from_spect2" + erange_lbound_delta = 5 + + correlations, shift = compare_utils.compare_between_spectra( + expt_spectrum, + ml_spectrum, + erange=erange, + erange_threshold=erange_threshold, + erange_lbound_delta=erange_lbound_delta, + truncation_strategy=truncation_strategy, + grid_interpolator=compare_utils.gridInterpolatorFixedSpacing(0.25), + output_correlations=other_metrics, + opt_strategy="grid_search_and_local_opt", + accuracy=0.1, + method=opt_metric, + norm_y_axis=True + ) + + # Calculate the comparison range + # The shift returned aligns ML spectrum to experimental spectrum + # ML spectrum energy range after shift: (ene + shift) + # The comparison uses erange (35 eV) starting from edge + + # For ML spectrum (spect2), find where edge starts + ml_y_normalized = (ml_spectrum[:, 1] - np.min(ml_spectrum[:, 1])) / (np.max(ml_spectrum[:, 1]) - np.min(ml_spectrum[:, 1])) + ml_edge_idx = np.argmax(ml_y_normalized > erange_threshold) + ml_edge_energy = ml_spectrum[ml_edge_idx, 0] + + # The comparison range in the EXPERIMENTAL spectrum's energy scale + # ML edge energy + shift = where ML edge aligns in exp energy scale + comparison_start = ml_edge_energy + shift + comparison_end = comparison_start + erange + + # Debug output + # print(f"=== Comparison Range Debug ===") + # print(f"ML edge energy: {ml_edge_energy:.1f} eV") + # print(f"Shift: {shift:.2f} eV") + # print(f"Comparison range: {comparison_start:.1f} - {comparison_end:.1f} eV") + + score = correlations.get(opt_metric, 0.0) + if np.isnan(score) or np.isinf(score): + score = 0.0 + + return { + 'score': round(float(score), 3), + 'correlations': {k: round(float(v), 3) if not (np.isnan(v) or np.isinf(v)) else 0.0 + for k, v in correlations.items()}, + 'shift': round(float(shift), 2), + 'comparison_range': (round(float(comparison_start), 1), round(float(comparison_end), 1)) + } + + except Exception as e: + print(f"Error in spectrum matching: {e}") + import traceback + traceback.print_exc() + return { + 'score': 0.0, + 'correlations': {}, + 'shift': 0.0, + 'comparison_range': None + } + + +# Store for matching results +matching_results_store = dcc.Store(id='matching_results_store', data=[]) +structure_scores_store = dcc.Store(id='structure_scores_store', data=[]) +comparison_range_store = dcc.Store(id='comparison_range_store', data=None) +selected_spectra_store = dcc.Store(id='selected_spectra_store', data=[]) +sort_metric_store = dcc.Store(id='sort_metric_store', data='coss_deriv') + +# Custom experimental spectrum upload component +exp_upload_component = dcc.Upload( + id='exp_spectrum_upload', + children=html.Div([ + html.Div([ + 'Drag and Drop or ', + html.A('Select File', style={'color': '#333', 'cursor': 'pointer', 'fontWeight': '500', 'textDecoration': 'underline'}) + ]) + ]), + style={ + 'width': '100%', + 'height': '50px', + 'lineHeight': '50px', + 'borderWidth': '1px', + 'borderStyle': 'dashed', + 'borderColor': '#d0d0d0', + 'borderRadius': '6px', + 'textAlign': 'center', + 'backgroundColor': '#fafafa', + 'cursor': 'pointer', + 'color': '#666', + 'fontSize': '12px', + 'fontFamily': "system-ui, -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif" + }, + multiple=False, + accept='.dat,.mat,.csv,.xdi' +) + +# Input for material name +exp_material_name_input = dcc.Input( + id='exp_material_name', + type='text', + placeholder='e.g., Anatase TiO2', + style={ + 'width': '100%', + 'padding': '10px 12px', + 'borderRadius': '6px', + 'border': '1px solid #ddd', + 'fontSize': '12px', + 'boxSizing': 'border-box', + 'fontFamily': "system-ui, -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif" + } +) + +# Store for raw file data (before column selection) +exp_raw_data_store = dcc.Store(id='exp_raw_data_store', data=None) + +# Store for column definitions +exp_columns_store = dcc.Store(id='exp_columns_store', data=None) + +# Store for final experimental spectrum data +exp_spectrum_store = dcc.Store(id='exp_spectrum_store', data=None) + +# Dynamic column definition area +exp_column_definition_area = html.Div( + id='exp_column_definition_area', + children=[], + style={'marginTop': '10px'} +) + +# Dropdown for X-axis column selection +exp_x_axis_dropdown = dcc.Dropdown( + id='exp_x_axis_dropdown', + options=[], + placeholder='Select X-axis column', + style={'marginBottom': '8px'} +) + +# Dropdown for Y-axis column selection +exp_y_axis_dropdown = dcc.Dropdown( + id='exp_y_axis_dropdown', + options=[], + placeholder='Select Y-axis column', + style={'marginBottom': '8px'} +) + +# Button to apply column selection and plot +exp_apply_btn = html.Button( + "Apply & Plot", + id="exp_apply_btn", + style={ + 'padding': '8px 16px', + 'fontSize': '12px', + 'border': 'none', + 'borderRadius': '6px', + 'backgroundColor': '#333', + 'color': 'white', + 'cursor': 'pointer', + 'fontWeight': '500', + 'marginRight': '8px', + 'fontFamily': "system-ui, -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif" + } +) + +# Display for uploaded experimental file info +exp_file_info = html.Div(id='exp_file_info', children='No experimental spectrum loaded', + style={ + 'fontSize': '11px', + 'color': '#888', + 'marginTop': '10px', + 'fontFamily': "system-ui, -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif" + }) + +# Button to clear experimental spectrum +clear_exp_btn = html.Button("Clear", id="clear_exp_btn", + style={ + 'fontSize': '12px', + 'padding': '8px 16px', + 'border': '1px solid #ddd', + 'borderRadius': '6px', + 'backgroundColor': 'white', + 'color': '#666', + 'cursor': 'pointer', + 'fontFamily': "system-ui, -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif" + }) + +# Common styles +base_font = "system-ui, -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif" + +section_header_style = { + "fontWeight": "600", + "fontSize": "13px", + "color": "#333", + "marginBottom": "14px", + "paddingBottom": "10px", + "borderBottom": "1px solid #eee", + "fontFamily": base_font, + "letterSpacing": "0.2px" +} + +column_header_style = { + "fontWeight": "600", + "fontSize": "13px", + "color": "#333", + "marginBottom": "14px", + "paddingBottom": "10px", + "borderBottom": "1px solid #eee", + "fontFamily": base_font, + "letterSpacing": "0.2px" +} + +input_label_style = { + "fontSize": "12px", + "color": "#666", + "marginBottom": "6px", + "fontWeight": "500", + "fontFamily": base_font +} + +card_style = { + "backgroundColor": "white", + "borderRadius": "8px", + "padding": "18px", + "marginBottom": "12px", + "border": "1px solid #e8e8e8" +} + +button_primary_style = { + 'padding': '10px 20px', + 'fontSize': '13px', + 'border': 'none', + 'borderRadius': '6px', + 'backgroundColor': '#333', + 'color': 'white', + 'cursor': 'pointer', + 'fontWeight': '500', + 'marginRight': '8px', + 'fontFamily': base_font +} + +button_secondary_style = { + 'padding': '8px 16px', + 'fontSize': '12px', + 'border': '1px solid #ddd', + 'borderRadius': '6px', + 'backgroundColor': 'white', + 'color': '#666', + 'cursor': 'pointer', + 'fontFamily': base_font +} + +onmixas_layout = html.Div([ + # Main content area + Columns([ + # Column 1: Input Controls + Column( + html.Div([ + # Experimental Spectrum Upload Card + html.Div([ + html.Div("Upload Experimental Spectrum", style=section_header_style), + + html.Div("Material Name (optional):", style=input_label_style), + exp_material_name_input, + + html.Div( + "Accepted formats: .csv, .dat, .mat, .xdi", + style={"fontSize": "11px", "color": "#999", "marginTop": "10px", "marginBottom": "8px"} + ), + + exp_upload_component, + exp_column_definition_area, + + html.Div( + id='exp_column_selection_area', + children=[ + html.Div("Select columns to plot:", style={**input_label_style, "marginTop": "12px"}), + html.Div([ + html.Div([ + html.Span("X-axis:", style={"fontSize": "11px", "display": "block", "marginBottom": "4px", "color": "#666"}), + exp_x_axis_dropdown, + ], style={"display": "inline-block", "width": "48%", "marginRight": "4%", "verticalAlign": "top"}), + html.Div([ + html.Span("Y-axis:", style={"fontSize": "11px", "display": "block", "marginBottom": "4px", "color": "#666"}), + exp_y_axis_dropdown, + ], style={"display": "inline-block", "width": "48%", "verticalAlign": "top"}), + ]), + html.Div([ + exp_apply_btn, + clear_exp_btn, + ], style={"marginTop": "12px"}), + ], + style={"display": "none"} + ), + + exp_file_info, + exp_raw_data_store, + exp_columns_store, + exp_spectrum_store, + ], style=card_style), + + # Load Structure Card + html.Div([ + html.Div("Load Structure", style=section_header_style), + + # Single structure search + html.Div("Search by Materials Project ID:", style={**input_label_style, "marginBottom": "8px"}), + Loading(search_component.layout()), + + html.Hr(style={"margin": "15px 0", "border": "none", "borderTop": "1px solid #eee"}), + + # Combined single/multiple file upload + html.Div("Upload structure file(s):", style={**input_label_style, "marginBottom": "4px"}), + html.Div( + "Single or multiple files • Supported: .cif, .vasp, .poscar, .json", + style={"fontSize": "10px", "color": "#999", "marginBottom": "8px"} + ), + batch_upload_component, + batch_processing_store, + + # Processing status + html.Div(id='batch_status', children='', style={ + "fontSize": "11px", + "color": "#666", + "marginTop": "8px", + "fontFamily": base_font + }), + + html.Div(st_source, style={"marginTop": "10px"}), + ], style=card_style), + + # XAS Model Prediction Card + html.Div([ + html.Div("XAS Model Prediction", style=section_header_style), + Loading(absorber_dropdown), + ], style=card_style), + + ], style={"width": "100%"}), + narrow=True, + ), + + # Column 2: Crystal Structure Viewer + Column( + html.Div([ + html.Div([ + html.Div("Crystal Structure Viewer", style=column_header_style), + Loading(struct_component.layout(size="100%")), + ], style={ + "backgroundColor": "white", + "borderRadius": "8px", + "padding": "18px", + "border": "1px solid #e8e8e8", + "minHeight": "500px" + }) + ]), + style={"flex": "1", "minWidth": "400px", "padding": "0 6px"} + ), + + # Column 3: Spectrum Analysis + Column( + html.Div([ + html.Div([ + html.Div("XANES Spectrum Analysis", style=column_header_style), + xas_plot, + + # Energy shift slider + html.Div([ + html.Div([ + html.Span("Shift Predicted Spectrum: ", style={"fontSize": "12px", "color": "#666", "fontFamily": base_font}), + html.Span(id='energy_shift_display', children="0.0 eV", + style={"fontSize": "12px", "fontWeight": "600", "color": "#333", "fontFamily": base_font}), + ], style={"marginTop": "15px", "marginBottom": "8px"}), + dcc.Slider( + id='energy_shift_slider', + min=-50, + max=50, + step=0.01, + value=0, + marks=None, + tooltip={"placement": "bottom", "always_visible": False}, + updatemode='drag', + included=False, + ), + html.Div([ + html.Span("-50 eV", style={"fontSize": "10px", "color": "#999", "fontFamily": base_font}), + html.Span("0", style={"fontSize": "10px", "color": "#999", "position": "absolute", "left": "50%", "transform": "translateX(-50%)", "fontFamily": base_font}), + html.Span("+50 eV", style={"fontSize": "10px", "color": "#999", "fontFamily": base_font}), + ], style={"display": "flex", "justifyContent": "space-between", "position": "relative", "marginTop": "-5px"}), + html.Button("Reset Shift", id="reset_shift_btn", style={**button_secondary_style, "marginTop": "10px", "fontSize": "11px", "padding": "6px 14px"}), + ], id='energy_shift_container', style={"padding": "0 10px"}), + + html.Hr(style={"margin": "20px 0", "border": "none", "borderTop": "1px solid #eee"}), + + html.Button("Download POSCAR and Spectrum", id="download_btn", style={ + **button_primary_style, + "width": "100%", + "padding": "12px", + "fontSize": "12px", + "marginRight": "0", + "borderRadius": "6px" + }), + dcc.Download(id="download_sink"), + + # Matching Results Section + html.Div([ + html.Div([ + html.Span("Structure Matching Scores", style={ + "fontWeight": "600", + "fontSize": "13px", + "color": "#333", + }), + html.Button("Clear All", id="clear_scores_btn", style={ + "fontSize": "10px", + "padding": "4px 10px", + "border": "1px solid #ddd", + "borderRadius": "4px", + "backgroundColor": "white", + "color": "#666", + "cursor": "pointer", + "marginLeft": "10px" + }), + ], style={ + "display": "flex", + "alignItems": "center", + "justifyContent": "space-between", + "marginTop": "20px", + "marginBottom": "12px", + "paddingBottom": "10px", + "borderBottom": "1px solid #eee" + }), + html.Div(id='matching_results_table', children=[ + html.Div("Upload experimental spectrum and load structures to see matching scores", + style={"color": "#999", "fontSize": "12px", "textAlign": "center", "padding": "20px"}) + ]), + structure_scores_store, + comparison_range_store, + selected_spectra_store, + sort_metric_store, + ]), + + ], style={ + "backgroundColor": "white", + "borderRadius": "8px", + "padding": "18px", + "border": "1px solid #e8e8e8" + }) + ]), + style={"flex": "1", "minWidth": "400px", "padding": "0 6px"} + ), ], desktop_only=False, - centered=False + centered=False), +], style={ + "backgroundColor": "#f5f5f5", + "minHeight": "100vh", + "padding": "12px", + "fontFamily": base_font +}) + +# Store for energy shift value +energy_shift_store = dcc.Store(id='energy_shift_store', data=0) + + +def parse_file_columns(contents, filename): + """ + Parse uploaded file and extract all columns with their data. + Supports XDI format with # Column.N: name headers. + """ + if contents is None: + return None + + content_type, content_string = contents.split(',') + decoded = b64decode(content_string) + + try: + if filename is None: + filename = "unknown.dat" + + ext = pathlib.Path(filename).suffix.lower() + print(f"=== DEBUG: Parsing file '{filename}' with extension '{ext}'") + + columns = [] + data = [] + + auto_x_col = 0 + auto_y_col = 1 + + if ext in ['.csv', '.dat', '.txt', '.xdi']: + text = decoded.decode('utf-8').replace('\r\n', '\n').replace('\r', '\n') + lines = [line.strip() for line in text.strip().split('\n') if line.strip()] + + comment_lines = [] + data_lines = [] + + for line in lines: + if line.startswith(('#', '%', '!')): + comment_lines.append(line) + else: + data_lines.append(line) + + if len(data_lines) == 0: + raise ValueError("No data lines found in file") + + xdi_columns = {} + energy_col_candidates = [] + absorption_col_candidates = [] + + for comment in comment_lines: + xdi_match = re.match(r'#\s*Column\.(\d+):\s*(.+)', comment, re.IGNORECASE) + if xdi_match: + col_num = int(xdi_match.group(1)) - 1 + col_name = xdi_match.group(2).strip() + xdi_columns[col_num] = col_name + print(f"=== DEBUG: Found XDI column {col_num}: '{col_name}'") + + col_lower = col_name.lower() + if any(term in col_lower for term in ['energy', ' e ', 'ev', 'photon']): + energy_col_candidates.append(col_num) + + if any(term in col_lower for term in ['norm', 'absorption', 'abs', 'mu', 'flat']): + absorption_col_candidates.append(col_num) + + if comment_lines and not xdi_columns: + last_comment = comment_lines[-1] + header_text = last_comment.lstrip('#').strip() + header_parts = header_text.split() + + if len(header_parts) >= 2 and ':' not in header_text: + print(f"=== DEBUG: Found inline header: {header_parts}") + for i, name in enumerate(header_parts): + xdi_columns[i] = name + name_lower = name.lower() + if name_lower in ['e', 'energy', 'ev']: + energy_col_candidates.append(i) + if name_lower in ['norm', 'flat', 'abs', 'mu', 'absorption']: + absorption_col_candidates.append(i) + + first_line = data_lines[0] + + if ',' in first_line: + delimiter = ',' + else: + delimiter = None + + first_parts = first_line.split(delimiter) if delimiter else first_line.split() + num_columns = len(first_parts) + + try: + float(first_parts[0].strip()) + header = None + start_idx = 0 + except ValueError: + header = [p.strip() for p in first_parts] + start_idx = 1 + if not xdi_columns: + for i, name in enumerate(header): + xdi_columns[i] = name + + data = [[] for _ in range(num_columns)] + + for line in data_lines[start_idx:]: + parts = line.split(delimiter) if delimiter else line.split() + for i, part in enumerate(parts): + if i < num_columns: + try: + data[i].append(float(part.strip())) + except ValueError: + pass + + for i in range(num_columns): + if i in xdi_columns: + col_name = xdi_columns[i] + elif header and i < len(header): + col_name = header[i] + else: + col_name = f"Column {i+1}" + + sample_values = data[i][:5] if len(data[i]) >= 5 else data[i] + columns.append({ + 'index': i, + 'name': col_name, + 'num_values': len(data[i]), + 'sample_values': sample_values + }) + + if energy_col_candidates: + auto_x_col = energy_col_candidates[0] + + if absorption_col_candidates: + for candidate in absorption_col_candidates: + col_name = xdi_columns.get(candidate, '').lower() + if 'norm' in col_name or 'flat' in col_name: + auto_y_col = candidate + break + else: + auto_y_col = absorption_col_candidates[0] + elif len(columns) > 1: + auto_y_col = 1 + + elif ext == '.mat': + try: + from scipy.io import loadmat + mat_data = loadmat(io.BytesIO(decoded)) + + data_keys = [k for k in mat_data.keys() if not k.startswith('__')] + + for i, key in enumerate(data_keys): + arr = mat_data[key] + if isinstance(arr, np.ndarray) and arr.size > 1: + flat_arr = arr.flatten().astype(float).tolist() + sample_values = flat_arr[:5] if len(flat_arr) >= 5 else flat_arr + columns.append({ + 'index': i, + 'name': key, + 'num_values': len(flat_arr), + 'sample_values': sample_values + }) + data.append(flat_arr) + + key_lower = key.lower() + if any(term in key_lower for term in ['energy', 'e', 'ev']): + auto_x_col = i + if any(term in key_lower for term in ['absorption', 'abs', 'mu', 'norm']): + auto_y_col = i + + except ImportError: + raise ValueError("scipy is required to read .mat files") + + else: + raise ValueError(f"Unsupported file format: {ext}") + + if len(columns) < 2: + raise ValueError("File must have at least 2 columns for X and Y axes") + + auto_x_col = min(auto_x_col, len(columns) - 1) + auto_y_col = min(auto_y_col, len(columns) - 1) + + if auto_x_col == auto_y_col and len(columns) > 1: + auto_y_col = 1 if auto_x_col == 0 else 0 + + print(f"=== DEBUG: Found {len(columns)} columns") + for col in columns: + print(f" Column {col['index']}: {col['name']} ({col['num_values']} values)") + print(f"=== DEBUG: Auto-selected X={auto_x_col}, Y={auto_y_col}") + + return { + 'columns': columns, + 'data': data, + 'filename': filename, + 'auto_x_col': auto_x_col, + 'auto_y_col': auto_y_col + } + + except Exception as e: + print(f"Error parsing file columns: {e}") + import traceback + traceback.print_exc() + return {'error': str(e)} + + +@app.callback( + Output('exp_raw_data_store', 'data'), + Output('exp_columns_store', 'data'), + Output('exp_x_axis_dropdown', 'options'), + Output('exp_y_axis_dropdown', 'options'), + Output('exp_x_axis_dropdown', 'value'), + Output('exp_y_axis_dropdown', 'value'), + Output('exp_column_selection_area', 'style'), + Output('exp_column_definition_area', 'children'), + Output('exp_file_info', 'children', allow_duplicate=True), + Output('exp_spectrum_upload', 'contents'), + Output('exp_spectrum_upload', 'filename'), + Output('exp_material_name', 'value'), + Input('exp_spectrum_upload', 'contents'), + Input('clear_exp_btn', 'n_clicks'), + State('exp_spectrum_upload', 'filename'), + prevent_initial_call=True +) +def handle_file_upload(contents, clear_clicks, filename): + """Handle file upload - parse columns and populate dropdowns.""" + ctx = dash.callback_context + + if not ctx.triggered: + raise PreventUpdate + + trigger_id = ctx.triggered[0]['prop_id'].split('.')[0] + + hidden_style = {"display": "none"} + visible_style = {"display": "block"} + + if trigger_id == 'clear_exp_btn': + return (None, None, [], [], None, None, hidden_style, [], + 'No experimental spectrum loaded', None, None, '') + + if contents is None: + raise PreventUpdate + + result = parse_file_columns(contents, filename) + + if result is None or 'error' in result: + error_msg = result.get('error', 'Failed to parse file') if result else 'Failed to parse file' + return (None, None, [], [], None, None, hidden_style, [], + html.Span(f"Error: {error_msg}", style={'color': 'red'}), + dash.no_update, dash.no_update, dash.no_update) + + columns = result['columns'] + options = [{'label': f"{col['name']} ({col['num_values']} pts)", 'value': col['index']} for col in columns] + + default_x = result.get('auto_x_col', 0) + default_y = result.get('auto_y_col', 1 if len(columns) > 1 else 0) + + max_visible_rows = 5 + table_height = "auto" if len(columns) <= max_visible_rows else f"{max_visible_rows * 40 + 30}px" + + col_definition = html.Div([ + html.Div(f"Detected {len(columns)} columns (edit names if needed):", + style={"fontSize": "12px", "marginBottom": "6px", "marginTop": "10px"}), + html.Div([ + html.Table([ + html.Thead(html.Tr([ + html.Th("#", style={"padding": "4px 8px", "fontSize": "11px", "width": "30px", "position": "sticky", "top": "0", "backgroundColor": "#fafafa", "zIndex": "1"}), + html.Th("Column Name", style={"padding": "4px 8px", "fontSize": "11px", "position": "sticky", "top": "0", "backgroundColor": "#fafafa", "zIndex": "1"}), + html.Th("Points", style={"padding": "4px 8px", "fontSize": "11px", "width": "50px", "position": "sticky", "top": "0", "backgroundColor": "#fafafa", "zIndex": "1"}), + html.Th("Sample Values", style={"padding": "4px 8px", "fontSize": "11px", "position": "sticky", "top": "0", "backgroundColor": "#fafafa", "zIndex": "1"}), + ])), + html.Tbody([ + html.Tr([ + html.Td(col['index'] + 1, style={"padding": "4px 8px", "fontSize": "11px", "verticalAlign": "middle"}), + html.Td( + dcc.Input( + id={'type': 'col-name-input', 'index': col['index']}, + type='text', + value=col['name'], + style={ + 'width': '100%', + 'padding': '4px', + 'fontSize': '11px', + 'border': '1px solid #ccc', + 'borderRadius': '3px' + } + ), + style={"padding": "4px"} + ), + html.Td(col['num_values'], style={"padding": "4px 8px", "fontSize": "11px", "verticalAlign": "middle"}), + html.Td( + ", ".join([f"{v:.2f}" for v in col['sample_values'][:3]]) + "...", + style={"padding": "4px 8px", "fontSize": "10px", "color": "#666", "verticalAlign": "middle"} + ), + ]) for col in columns + ]) + ], style={"borderCollapse": "collapse", "width": "100%"}) + ], style={ + "maxHeight": table_height, + "overflowY": "auto" if len(columns) > max_visible_rows else "visible", + "border": "1px solid #ddd", + "marginBottom": "10px" + }), + + html.Button("Update Column Names", id="exp_update_col_names_btn", + style={"fontSize": "11px", "padding": "4px 8px", "marginBottom": "10px"}) + ]) + + x_col_name = columns[default_x]['name'] if default_x < len(columns) else "Column 1" + y_col_name = columns[default_y]['name'] if default_y < len(columns) else "Column 2" + info_text = f"File loaded: {filename} (auto-selected: X={x_col_name}, Y={y_col_name})" + + material_name_from_file = pathlib.Path(filename).stem if filename else "" + + return (result, columns, options, options, default_x, default_y, visible_style, col_definition, + html.Span(info_text, style={'color': 'blue'}), + dash.no_update, dash.no_update, material_name_from_file) + + +@app.callback( + Output('exp_columns_store', 'data', allow_duplicate=True), + Output('exp_x_axis_dropdown', 'options', allow_duplicate=True), + Output('exp_y_axis_dropdown', 'options', allow_duplicate=True), + Output('exp_file_info', 'children', allow_duplicate=True), + Input('exp_update_col_names_btn', 'n_clicks'), + State({'type': 'col-name-input', 'index': ALL}, 'value'), + State('exp_columns_store', 'data'), + prevent_initial_call=True +) +def update_column_names(n_clicks, new_names, columns): + """Update column names when user edits them.""" + if n_clicks is None or columns is None: + raise PreventUpdate + + for i, new_name in enumerate(new_names): + if i < len(columns): + columns[i]['name'] = new_name.strip() if new_name else f"Column {i+1}" + + options = [{'label': f"{col['name']} ({col['num_values']} pts)", 'value': col['index']} for col in columns] + + return columns, options, options, html.Span("Column names updated!", style={'color': 'green'}) + + +@app.callback( + Output('exp_spectrum_store', 'data'), + Output('exp_file_info', 'children', allow_duplicate=True), + Input('exp_apply_btn', 'n_clicks'), + State('exp_raw_data_store', 'data'), + State('exp_columns_store', 'data'), + State('exp_x_axis_dropdown', 'value'), + State('exp_y_axis_dropdown', 'value'), + State('exp_material_name', 'value'), + prevent_initial_call=True ) +def apply_column_selection(n_clicks, raw_data, columns, x_col_idx, y_col_idx, material_name): + """Apply column selection and create the spectrum data for plotting.""" + if n_clicks is None or raw_data is None: + raise PreventUpdate + + if x_col_idx is None or y_col_idx is None: + return None, html.Span("Please select both X and Y axis columns", style={'color': 'red'}) + + try: + data = raw_data['data'] + filename = raw_data['filename'] + + x_data = np.array(data[x_col_idx]) + y_data = np.array(data[y_col_idx]) + + min_len = min(len(x_data), len(y_data)) + x_data = x_data[:min_len] + y_data = y_data[:min_len] + + if len(x_data) < 2: + return None, html.Span("Not enough data points", style={'color': 'red'}) + + sort_idx = np.argsort(x_data) + x_data = x_data[sort_idx] + y_data = y_data[sort_idx] + + x_label = columns[x_col_idx]['name'] + y_label = columns[y_col_idx]['name'] + + display_name = material_name if material_name and material_name.strip() else filename + + result = { + 'energy': x_data.tolist(), + 'absorption': y_data.tolist(), + 'filename': filename, + 'material_name': display_name, + 'x_label': x_label, + 'y_label': y_label + } + + x_min, x_max = x_data.min(), x_data.max() + info_text = f"✓ {display_name} ({len(x_data)} points, {x_label}: {x_min:.1f}-{x_max:.1f})" + + return result, html.Span(info_text, style={'color': 'green'}) + + except Exception as e: + print(f"Error applying column selection: {e}") + return None, html.Span(f"Error: {str(e)}", style={'color': 'red'}) @app.callback( @@ -75,12 +1005,8 @@ def download_xas_prediction(n_clicks, st_data, el_type): el, theory = el_type.split(' ') st = Structure.from_dict(st_data) d_xas = st_data['xas'] - specs_list = [ene_grid[el]] + list(d_xas.values()) - avg_spec = np.stack(specs_list[1:]).mean(axis=0) - specs_list.append(avg_spec) - specs = np.stack(specs_list) - - site_idxs = ["Energy"] + [f'Atom #{int(i) + 1}' for i in d_xas.keys()] + ['Average'] + specs = np.stack([ene_grid[el]] + list(d_xas.values())) + site_idxs = ["Energy"] + [f'Atom #{int(i) + 1}' for i in d_xas.keys()] df = pd.DataFrame(specs, index=site_idxs) with tempfile.TemporaryDirectory() as td: tmpdir = pathlib.Path(td) @@ -93,8 +1019,7 @@ def download_xas_prediction(n_clicks, st_data, el_type): st.to(fn_poscar, fmt='poscar') df.to_csv(fn_spec, float_format="%.3f", header=False) zip_fn = tmpdir / f'OmniXAS_{el}_{theory}_Prediction_{n_clicks}.zip' - with ZipFile(zip_fn, - mode="w") as zip_file: + with ZipFile(zip_fn, mode="w") as zip_file: for fn in files_to_zip: zip_file.write(fn, arcname=fn.name) bytes = b64encode((tmpdir / zip_fn).read_bytes()).decode("ascii") @@ -108,24 +1033,21 @@ def download_xas_prediction(n_clicks, st_data, el_type): @app.callback( Output(struct_component.id(), "data", allow_duplicate=True), - Output(upload_component.id("upload_data"), "contents"), Output('st_source', "children", allow_duplicate=True), Input(search_component.id(), "data"), State('absorber', 'value') ) -def update_structure_by_mpid(search_mpid: str, el_type) -> Structure: # pyright: ignore[reportRedeclaration] +def update_structure_by_mpid(search_mpid: str, el_type) -> Structure: if not search_mpid: raise PreventUpdate with MPRester() as mpr: st = mpr.get_structure_by_material_id(search_mpid) if not isinstance(st, Structure): - raise Exception("mp_api MPRester.get_structure_by_material_id did not return a pymatgen \"Structure\" object. This has been observed to occur when using an outdated version of mp_api with a more recent version of emmet-core. For now, please use the versions specified by LightShow's pyproject.toml") - - print("Struct from material.") - + raise Exception("mp_api MPRester.get_structure_by_material_id did not return a pymatgen Structure object.") + st_dict = decorate_structure_with_xas(st, el_type) - return st_dict, None, f"Current structure: {search_mpid}" # pyright: ignore[reportReturnType] + return st_dict, f"Current structure: {search_mpid}" def decorate_structure_with_xas(st: Structure, el_type): @@ -139,55 +1061,419 @@ def decorate_structure_with_xas(st: Structure, el_type): return st_dict -@app.callback( - Output(struct_component.id(), "data", allow_duplicate=True), - Output('st_source', "children", allow_duplicate=True), - Input(upload_component.id(), "data"), - State(upload_component.id('upload_data'), 'filename'), - State('absorber', 'value') -) -def update_structure_by_file(upload_data: dict, fn, el_type) -> Structure: - if not upload_data: - raise PreventUpdate - st = Structure.from_dict(upload_data['data']) - st_dict = decorate_structure_with_xas(st, el_type) - return st_dict, f"Current structure: {fn}" # pyright: ignore[reportReturnType] +def parse_structure_file(contents, filename): + """ + Parse a structure file from base64-encoded contents. + Supports CIF, VASP/POSCAR, and JSON formats. + """ + try: + content_type, content_string = contents.split(',') + decoded = b64decode(content_string) + + ext = pathlib.Path(filename).suffix.lower() + + if ext in ['.cif']: + # CIF format + from pymatgen.io.cif import CifParser + text = decoded.decode('utf-8') + parser = CifParser.from_str(text) + st = parser.parse_structures()[0] + elif ext in ['.vasp', '.poscar', '']: + # VASP/POSCAR format + from pymatgen.io.vasp import Poscar + text = decoded.decode('utf-8') + poscar = Poscar.from_str(text) + st = poscar.structure + elif ext == '.json': + # JSON format (pymatgen Structure dict) + import json + text = decoded.decode('utf-8') + data = json.loads(text) + st = Structure.from_dict(data) + else: + # Try to auto-detect format + text = decoded.decode('utf-8') + try: + # Try CIF first + from pymatgen.io.cif import CifParser + parser = CifParser.from_str(text) + st = parser.parse_structures()[0] + except: + try: + # Try POSCAR + from pymatgen.io.vasp import Poscar + poscar = Poscar.from_str(text) + st = poscar.structure + except: + raise ValueError(f"Could not parse file format: {ext}") + + return st + except Exception as e: + print(f"Error parsing structure file {filename}: {e}") + import traceback + traceback.print_exc() + return None @app.callback( - Output("xas_plot", "figure", allow_duplicate=True), - Input(struct_component.id(), "data"), - State('absorber', 'value') + Output('structure_scores_store', 'data', allow_duplicate=True), + Output('matching_results_table', 'children', allow_duplicate=True), + Output('comparison_range_store', 'data', allow_duplicate=True), + Output('batch_status', 'children'), + Output('batch_structure_upload', 'contents'), + Output(struct_component.id(), "data", allow_duplicate=True), + Output('st_source', "children", allow_duplicate=True), + Input('batch_structure_upload', 'contents'), + State('batch_structure_upload', 'filename'), + State('exp_spectrum_store', 'data'), + State('absorber', 'value'), + State('structure_scores_store', 'data'), + State('sort_metric_store', 'data'), + prevent_initial_call=True ) -def predict_average_xas(st_data: dict, el_type) -> Structure: - if not st_data: +def handle_batch_upload(contents_list, filenames_list, exp_data, el_type, existing_scores, sort_metric): + """ + Handle batch upload of multiple structure files. + Parse each file, generate XAS spectrum, and compare with experimental data. + """ + if contents_list is None or len(contents_list) == 0: raise PreventUpdate - specs = st_data['xas'] - if len(specs) == 0: - fig = build_figure(None, el_type, is_average=True, no_element=True, sel_mismatch=False) + + if existing_scores is None: + existing_scores = [] + + if sort_metric is None: + sort_metric = 'coss_deriv' + + has_exp_data = exp_data is not None and 'energy' in exp_data and 'absorption' in exp_data + + element = el_type.split(' ')[0] + + # Process each uploaded file + successful = 0 + failed = 0 + failed_files = [] + last_st_dict = None + last_filename = None + comparison_range = None + + for contents, filename in zip(contents_list, filenames_list): + try: + # Parse the structure file + st = parse_structure_file(contents, filename) + + if st is None: + failed += 1 + failed_files.append(filename) + continue + + # Check if structure contains the absorbing element + if element not in st.composition: + print(f"Structure {filename} does not contain {element}, skipping...") + failed += 1 + failed_files.append(f"{filename} (no {element})") + continue + + # Generate XAS spectrum + specs = predict(st, element, el_type.split(' ')[1]) + + if len(specs) == 0: + failed += 1 + failed_files.append(f"{filename} (no spectrum)") + continue + + # Calculate average spectrum + specs_array = np.array(list(specs.values())) + predicted_spectrum = specs_array.mean(axis=0) + energy = ene_grid[element].tolist() + + # Get structure ID from filename (remove extension) + structure_id = pathlib.Path(filename).stem + + # Compare with experimental data if available + + if has_exp_data: + match_result = get_spectrum_match_score(predicted_spectrum, exp_data, element) + else: + match_result = { + 'score': 0.0, + 'correlations': {}, + 'shift': 0.0, + 'comparison_range': None + } + + # Check if this structure already exists - preserve selection state + old_entry = next((s for s in existing_scores if s['structure_id'] == structure_id), None) + was_selected = old_entry.get('selected', False) if old_entry else False + + # Remove old entry if exists + existing_scores = [s for s in existing_scores if s['structure_id'] != structure_id] + + # Add new score entry + existing_scores.append({ + 'structure_id': structure_id, + 'score': match_result['score'], + 'shift': match_result['shift'], + 'correlations': match_result['correlations'], + 'comparison_range': match_result['comparison_range'], + 'spectrum': predicted_spectrum.tolist(), + 'energy': energy, + 'element': element, + 'selected': was_selected + }) + + # Keep track of comparison range from last successful processing + if match_result['comparison_range'] is not None: + comparison_range = match_result['comparison_range'] + + # Store last structure for display + st_dict = st.as_dict() + st_dict['xas'] = specs + last_st_dict = st_dict + last_filename = filename + + successful += 1 + + except Exception as e: + print(f"Error processing {filename}: {e}") + import traceback + traceback.print_exc() + failed += 1 + failed_files.append(filename) + + # Sort scores by current metric + existing_scores = sort_scores_by_metric(existing_scores, sort_metric) + + # Build status message + if successful > 0 and failed == 0: + status_msg = html.Span(f"✓ Processed {successful} structure(s) successfully", style={'color': 'green'}) + elif successful > 0 and failed > 0: + status_msg = html.Span([ + html.Span(f"✓ Processed {successful} structure(s). ", style={'color': 'green'}), + html.Span(f"✗ Failed: {failed} ({', '.join(failed_files[:3])}{'...' if len(failed_files) > 3 else ''})", style={'color': 'orange'}) + ]) else: - specs = np.array(list(specs.values())) - spectrum = specs.mean(axis=0) - fig = build_figure(spectrum, el_type, is_average=True, no_element=False, sel_mismatch=False) - return fig # pyright: ignore[reportReturnType] + status_msg = html.Span(f"✗ Failed to process all {failed} file(s)", style={'color': 'red'}) + + # Update source text + if successful == 1: + source_text = f"Current structure: {last_filename}" + elif successful > 1: + source_text = f"Batch loaded: {successful} structures" + else: + source_text = "No structures loaded" + + return ( + existing_scores, + build_scores_table(existing_scores, sort_metric), + comparison_range, + status_msg, + None, # Clear the upload contents + last_st_dict if last_st_dict else dash.no_update, + source_text + ) -def build_figure(spectrum, el_type, is_average, no_element, sel_mismatch): +def build_figure_with_exp(predicted_spectrum, exp_data, el_type, is_average, no_element, sel_mismatch, energy_shift=0, comparison_range=None, selected_spectra=None, current_structure_id=None): + """ + Build a plotly figure with predicted spectrum and optional experimental overlay. + The comparison_range parameter zooms the plot to the energy range used for comparison. + """ element = el_type.split(" ")[0] - if spectrum is None: - ene = None - else: - ene = ene_grid[element] - if no_element: + fig = go.Figure() + + has_exp_data = exp_data is not None and 'energy' in exp_data and 'absorption' in exp_data + has_selected = selected_spectra is not None and len(selected_spectra) > 0 + + if has_selected: + num_selected = len(selected_spectra) + title = f'Comparing {num_selected} Structure{"s" if num_selected > 1 else ""} with Experimental' + elif predicted_spectrum is None and has_exp_data: + exp_display_name = exp_data.get('material_name', exp_data.get('filename', 'Experimental')) + title = f'Experimental Spectrum: {exp_display_name}' + elif no_element: title = f"This structure doesn't contain {element}" elif sel_mismatch: title = f"The selected atom is not a {element} atom" elif is_average: title = f'Average K-edge XANES Spectrum of {el_type}' + if has_exp_data: + title += " (with Experimental)" else: title = f'K-edge XANES Spectrum for the selected {element} atom' - fig = px.scatter(x=ene, y=spectrum, title=title, - labels={'x': "Energy (eV)", "y": "Absorption"}) + if has_exp_data: + title += " (with Experimental)" + + exp_energy = None + exp_absorption = None + if has_exp_data: + exp_energy = np.array(exp_data['energy']) + exp_absorption = np.array(exp_data['absorption']) + + colors = ['#636EFA', '#00CC96', '#AB63FA', '#FFA15A', '#19D3F3', '#FF6692', '#B6E880', '#FF97FF', '#FECB52'] + + if has_selected: + for idx, spec_entry in enumerate(selected_spectra): + spec_data = np.array(spec_entry['spectrum']) + spec_energy = np.array(spec_entry['energy']) + spec_shift = spec_entry.get('shift', 0.0) + structure_id = spec_entry['structure_id'] + + spec_energy_shifted = spec_energy + spec_shift + + if has_exp_data and len(exp_absorption) > 0: + pred_range = np.max(spec_data) - np.min(spec_data) + exp_range = np.max(exp_absorption) - np.min(exp_absorption) + + if pred_range > 0 and exp_range > 0: + spec_normalized = (spec_data - np.min(spec_data)) / pred_range + spec_scaled = spec_normalized * exp_range + np.min(exp_absorption) + else: + spec_scaled = spec_data + else: + spec_scaled = spec_data + + color = colors[idx % len(colors)] + fig.add_trace(go.Scatter( + x=spec_energy_shifted, + y=spec_scaled, + mode='lines', + name=f'{structure_id}', + line=dict(color=color, width=2), + )) + + elif predicted_spectrum is not None: + ene = ene_grid[element] + ene_shifted = ene + energy_shift + + predicted_was_normalized = False + if has_exp_data and len(exp_absorption) > 0: + pred_range = np.max(predicted_spectrum) - np.min(predicted_spectrum) + exp_range = np.max(exp_absorption) - np.min(exp_absorption) + + if pred_range > 0 and exp_range > 0: + pred_normalized = (predicted_spectrum - np.min(predicted_spectrum)) / pred_range + pred_scaled = pred_normalized * exp_range + np.min(exp_absorption) + predicted_was_normalized = True + else: + pred_scaled = predicted_spectrum + else: + pred_scaled = predicted_spectrum + + if current_structure_id: + pred_name = f'{current_structure_id}' + if predicted_was_normalized: + pred_name += ' (normalized)' + else: + pred_name = 'Predicted (normalized)' if predicted_was_normalized else 'Predicted' + + if energy_shift != 0: + pred_name += f' [{energy_shift:+.1f} eV]' + + fig.add_trace(go.Scatter( + x=ene_shifted, + y=pred_scaled, + mode='lines', + name=pred_name, + line=dict(color='#636EFA', width=2), + )) + + if has_exp_data: + exp_display_name = exp_data.get('material_name', exp_data.get('filename', 'Experimental')) + fig.add_trace(go.Scatter( + x=exp_energy, + y=exp_absorption, + mode='markers', + name=f'Exp: {exp_display_name}', + marker=dict(color='#EF553B', size=4), + )) + + if has_exp_data: + x_axis_label = exp_data.get('x_label', 'Energy (eV)') + y_axis_label = exp_data.get('y_label', 'Absorption') + else: + x_axis_label = "Energy (eV)" + y_axis_label = "Absorption" + + layout_config = dict( + title=title, + xaxis_title=x_axis_label, + yaxis_title=y_axis_label, + legend=dict( + yanchor="bottom", + y=0.01, + xanchor="right", + x=0.99, + bgcolor="rgba(0,0,0,0)", + bordercolor="rgba(0,0,0,0)", + font=dict(size=10) + ), + hovermode='x unified' + ) + + # Apply comparison range to x-axis to zoom into the comparison region + # Only apply if we have both experimental data and a valid comparison range + if has_exp_data and comparison_range is not None and len(comparison_range) == 2: + x_start, x_end = comparison_range + # Validate the range makes sense + if x_start < x_end and x_end - x_start > 5: # At least 5 eV range + # Add 10% padding on each side for better visualization + range_width = x_end - x_start + padding = range_width * 0.1 + layout_config['xaxis'] = dict( + range=[x_start - padding, x_end + padding], + title=x_axis_label + ) + print(f"=== Plot x-axis range set to: {x_start - padding:.1f} - {x_end + padding:.1f} eV ===") + + fig.update_layout(**layout_config) + return fig + + +@app.callback( + Output("xas_plot", "figure", allow_duplicate=True), + Input(struct_component.id(), "data"), + Input('exp_spectrum_store', 'data'), + Input('energy_shift_slider', 'value'), + Input('comparison_range_store', 'data'), + Input('structure_scores_store', 'data'), + State('absorber', 'value'), + State('st_source', 'children') +) +def predict_average_xas(st_data: dict, exp_data: dict, energy_shift: float, comparison_range, structure_scores, el_type, structure_source) -> Structure: + if st_data is None and exp_data is None: + raise PreventUpdate + + current_structure_id = None + if structure_source and isinstance(structure_source, str): + if ":" in structure_source: + current_structure_id = structure_source.split(":")[-1].strip() + else: + current_structure_id = structure_source + + selected_spectra = None + if structure_scores: + selected_spectra = [s for s in structure_scores if s.get('selected', False) and 'spectrum' in s] + if len(selected_spectra) == 0: + selected_spectra = None + + predicted_spectrum = None + no_element = False + + if selected_spectra is None and st_data is not None: + specs = st_data.get('xas', {}) + if len(specs) == 0: + no_element = True + else: + specs_array = np.array(list(specs.values())) + predicted_spectrum = specs_array.mean(axis=0) + + fig = build_figure_with_exp( + predicted_spectrum, exp_data, el_type, + is_average=True, no_element=no_element, sel_mismatch=False, + energy_shift=energy_shift or 0, comparison_range=comparison_range, + selected_spectra=selected_spectra, current_structure_id=current_structure_id + ) return fig @@ -195,33 +1481,46 @@ def build_figure(spectrum, el_type, is_average, no_element, sel_mismatch): Output("xas_plot", "figure", allow_duplicate=True), Input(struct_component.id('scene'), "selectedObject"), State(struct_component.id(), 'data'), - State('absorber', 'value') + State('exp_spectrum_store', 'data'), + State('absorber', 'value'), + State('energy_shift_slider', 'value'), + State('comparison_range_store', 'data'), + State('st_source', 'children') ) -def predict_site_specific_xas(sel, st_data, el_type) -> Structure: +def predict_site_specific_xas(sel, st_data, exp_data, el_type, energy_shift, comparison_range, structure_source) -> Structure: + if st_data is None: + raise PreventUpdate + + current_structure_id = None + if structure_source and isinstance(structure_source, str): + if ":" in structure_source: + current_structure_id = structure_source.split(":")[-1].strip() + else: + current_structure_id = structure_source + specs = st_data['xas'] element = el_type.split(' ')[0] + shift = energy_shift or 0 if len(specs) == 0: - fig = build_figure(None, el_type, is_average=False, no_element=True, sel_mismatch=False) - elif len(sel) == 0: + fig = build_figure_with_exp(None, exp_data, el_type, is_average=False, no_element=True, sel_mismatch=False, energy_shift=shift, comparison_range=comparison_range, current_structure_id=current_structure_id) + elif sel is None or len(sel) == 0: specs = np.array(list(specs.values())) spectrum = specs.mean(axis=0) - fig = build_figure(spectrum, el_type, is_average=True, no_element=False, sel_mismatch=False) + fig = build_figure_with_exp(spectrum, exp_data, el_type, is_average=True, no_element=False, sel_mismatch=False, energy_shift=shift, comparison_range=comparison_range, current_structure_id=current_structure_id) else: st = Structure.from_dict(st_data) - el_sel = sel[0]['tooltip'].split('(')[0].strip() - pos_sel = np.array([float(x) for x in sel[0]['tooltip'].split('(')[1].split(')')[0].split(',')]) - frac_pos_sel = st.lattice.get_fractional_coords(pos_sel) - dist = st.lattice.get_all_distances(frac_pos_sel, st.frac_coords) - dist = dist[0] - i_site = np.argmin(dist) - assert dist[i_site] < 0.01 - assert st[i_site].specie.symbol == el_sel # pyright: ignore[reportArgumentType] - if st[i_site].specie.symbol != element: # pyright: ignore[reportArgumentType] - fig = build_figure(None, el_type, is_average=False, no_element=False, sel_mismatch=True) + spheres = st._get_sites_to_draw() + spheres = list(spheres) + i_sphere = int(sel[0]['id'].split('--')[-1]) + cur_sphere = spheres[i_sphere] + i_site = cur_sphere[0] + if st[i_site].specie.symbol != element: + fig = build_figure_with_exp(None, exp_data, el_type, is_average=False, no_element=False, sel_mismatch=True, energy_shift=shift, comparison_range=comparison_range, current_structure_id=current_structure_id) else: spectrum = np.array(specs[str(i_site)]) - fig = build_figure(spectrum, el_type, is_average=False, no_element=False, sel_mismatch=False) - return fig # pyright: ignore[reportReturnType] + site_structure_id = f"{current_structure_id} (site {i_site})" if current_structure_id else None + fig = build_figure_with_exp(spectrum, exp_data, el_type, is_average=False, no_element=False, sel_mismatch=False, energy_shift=shift, comparison_range=comparison_range, current_structure_id=site_structure_id) + return fig @app.callback( @@ -229,10 +1528,333 @@ def predict_site_specific_xas(sel, st_data, el_type) -> Structure: Input('absorber', 'value'), State(struct_component.id(), "data") ) -def update_structure_by_mpid(el_type, st_data) -> Structure: +def update_structure_by_absorber(el_type, st_data) -> Structure: + if st_data is None: + raise PreventUpdate st = Structure.from_dict(st_data) st_dict = decorate_structure_with_xas(st, el_type) - return st_dict # pyright: ignore[reportReturnType] + return st_dict + + +@app.callback( + Output('energy_shift_slider', 'value'), + Input('reset_shift_btn', 'n_clicks'), + prevent_initial_call=True +) +def reset_energy_shift(n_clicks): + if n_clicks is None: + raise PreventUpdate + return 0 + + +@app.callback( + Output('energy_shift_display', 'children'), + Input('energy_shift_slider', 'value') +) +def update_shift_display(value): + if value is None: + value = 0 + return f"{value:+.1f} eV" + + +@app.callback( + Output('sort_metric_store', 'data'), + Input({'type': 'sort-metric-btn', 'metric': ALL}, 'n_clicks'), + State('sort_metric_store', 'data'), + prevent_initial_call=True +) +def handle_sort_click(n_clicks_list, current_sort_metric): + """Handle clicks on sortable column headers to change the sort metric.""" + ctx = dash.callback_context + if not ctx.triggered: + raise PreventUpdate + + trigger_id = ctx.triggered[0]['prop_id'] + import json + try: + id_str = trigger_id.rsplit('.', 1)[0] + id_dict = json.loads(id_str) + clicked_metric = id_dict['metric'] + except Exception: + raise PreventUpdate + + return clicked_metric + + +@app.callback( + Output('structure_scores_store', 'data'), + Output('matching_results_table', 'children'), + Output('comparison_range_store', 'data'), + Input(struct_component.id(), "data"), + Input('exp_spectrum_store', 'data'), + Input('clear_scores_btn', 'n_clicks'), + Input({'type': 'spectrum-checkbox', 'index': ALL}, 'value'), + Input('sort_metric_store', 'data'), + State('structure_scores_store', 'data'), + State('st_source', 'children'), + State('absorber', 'value'), + prevent_initial_call=True +) +def update_matching_results(st_data, exp_data, clear_clicks, checkbox_values, sort_metric, existing_scores, structure_source, el_type): + """Update the matching results table when a structure is loaded and experimental data is available.""" + ctx = dash.callback_context + + if not ctx.triggered: + raise PreventUpdate + + trigger_id = ctx.triggered[0]['prop_id'] + + if existing_scores is None: + existing_scores = [] + + if sort_metric is None: + sort_metric = 'coss_deriv' + + if 'clear_scores_btn' in trigger_id: + return [], html.Div("Upload experimental spectrum and load structures to see matching scores", + style={"color": "#999", "fontSize": "12px", "textAlign": "center", "padding": "20px"}), None + + if 'spectrum-checkbox' in trigger_id: + for i, score_entry in enumerate(existing_scores): + if i < len(checkbox_values): + score_entry['selected'] = bool(checkbox_values[i]) + existing_scores = sort_scores_by_metric(existing_scores, sort_metric) + return existing_scores, build_scores_table(existing_scores, sort_metric), dash.no_update + + if 'sort_metric_store' in trigger_id: + existing_scores = sort_scores_by_metric(existing_scores, sort_metric) + return existing_scores, build_scores_table(existing_scores, sort_metric), dash.no_update + + has_exp_data = exp_data is not None and 'energy' in exp_data and 'absorption' in exp_data + + if not has_exp_data: + if len(existing_scores) == 0: + return existing_scores, html.Div("Upload experimental spectrum first to enable matching", + style={"color": "#999", "fontSize": "12px", "textAlign": "center", "padding": "20px"}), None + else: + return existing_scores, build_scores_table(existing_scores, sort_metric), dash.no_update + + if st_data is None: + if len(existing_scores) == 0: + return existing_scores, html.Div("Load a structure to see matching scores", + style={"color": "#999", "fontSize": "12px", "textAlign": "center", "padding": "20px"}), None + else: + return existing_scores, build_scores_table(existing_scores, sort_metric), dash.no_update + + specs = st_data.get('xas', {}) + if len(specs) == 0: + if len(existing_scores) == 0: + return existing_scores, html.Div("No spectrum available for matching", + style={"color": "#999", "fontSize": "12px", "textAlign": "center", "padding": "20px"}), None + else: + return existing_scores, build_scores_table(existing_scores, sort_metric), dash.no_update + + specs_array = np.array(list(specs.values())) + predicted_spectrum = specs_array.mean(axis=0) + element = el_type.split(' ')[0] + energy = ene_grid[element].tolist() + + structure_id = "unknown" + if structure_source and isinstance(structure_source, str): + if ":" in structure_source: + structure_id = structure_source.split(":")[-1].strip() + else: + structure_id = structure_source + + match_result = get_spectrum_match_score(predicted_spectrum, exp_data, element) + + old_entry = next((s for s in existing_scores if s['structure_id'] == structure_id), None) + was_selected = old_entry.get('selected', False) if old_entry else False + + updated_scores = [s for s in existing_scores if s['structure_id'] != structure_id] + + updated_scores.append({ + 'structure_id': structure_id, + 'score': match_result['score'], + 'shift': match_result['shift'], + 'correlations': match_result['correlations'], + 'comparison_range': match_result['comparison_range'], + 'spectrum': predicted_spectrum.tolist(), + 'energy': energy, + 'element': element, + 'selected': was_selected + }) + + updated_scores = sort_scores_by_metric(updated_scores, sort_metric) + return updated_scores, build_scores_table(updated_scores, sort_metric), match_result['comparison_range'] + + +def sort_scores_by_metric(scores, metric): + """Sort scores list by the given metric. For normed_wasserstein, lower is better (sort ascending).""" + if not scores: + return scores + + reverse = metric != 'normed_wasserstein' + + def sort_key(entry): + correlations = entry.get('correlations', {}) + val = correlations.get(metric, 0.0) + if val is None or (isinstance(val, float) and (np.isnan(val) or np.isinf(val))): + return -999 if reverse else 999 + return val + + return sorted(scores, key=sort_key, reverse=reverse) + + +def build_scores_table(scores, sort_metric='coss_deriv'): + """Build the HTML table for displaying structure scores with all metrics as sortable columns.""" + if not scores: + return html.Div("No scores yet", + style={"color": "#999", "fontSize": "12px", "textAlign": "center", "padding": "20px"}) + + base_header_style = { + "padding": "5px 4px", + "textAlign": "right", + "fontWeight": "600", + "fontSize": "10px", + "color": "#666", + "borderBottom": "2px solid #e8e8e8", + "backgroundColor": "#fafafa", + "whiteSpace": "nowrap", + } + + active_header_style = { + **base_header_style, + "color": "#333", + "borderBottom": "2px solid #333", + "backgroundColor": "#f0f0f0", + } + + table_cell_style = { + "padding": "5px 4px", + "fontSize": "11px", + "color": "#333", + "borderBottom": "1px solid #eee", + "textAlign": "right", + } + + header_cells = [ + html.Th("", style={**base_header_style, "width": "28px", "textAlign": "center"}), + html.Th("#", style={**base_header_style, "width": "22px", "textAlign": "center"}), + html.Th("Structure", style={**base_header_style, "textAlign": "left", "minWidth": "70px"}), + html.Th("Shift", style={**base_header_style, "width": "50px"}), + ] + + for metric in ALL_METRICS: + is_active = (metric == sort_metric) + style = active_header_style if is_active else base_header_style + arrow = " ▼" if is_active and metric != 'normed_wasserstein' else (" ▲" if is_active else "") + + header_cells.append( + html.Th( + html.Button( + METRIC_SHORT_NAMES[metric] + arrow, + id={'type': 'sort-metric-btn', 'metric': metric}, + style={ + "border": "none", + "background": "none", + "cursor": "pointer", + "fontWeight": "700" if is_active else "600", + "fontSize": "10px", + "color": "#333" if is_active else "#666", + "padding": "0", + "fontFamily": base_font, + "textDecoration": "none", + "whiteSpace": "nowrap", + }, + title=f"Sort by {metric}" + (" (lower is better)" if metric == 'normed_wasserstein' else " (higher is better)"), + ), + style=style, + ) + ) + + header = html.Tr(header_cells) + + rows = [] + for rank, entry in enumerate(scores): + correlations = entry.get('correlations', {}) + shift = entry.get('shift', 0.0) + is_selected = entry.get('selected', False) + + row_cells = [ + html.Td( + dcc.Checklist( + id={'type': 'spectrum-checkbox', 'index': rank}, + options=[{'label': '', 'value': True}], + value=[True] if is_selected else [], + style={"margin": "0", "padding": "0"}, + inputStyle={"marginRight": "0"} + ), + style={**table_cell_style, "textAlign": "center", "padding": "3px"} + ), + html.Td(rank + 1, style={**table_cell_style, "color": "#999", "fontWeight": "500", "textAlign": "center"}), + html.Td(entry['structure_id'], style={ + **table_cell_style, + "fontFamily": "monospace", + "fontSize": "10px", + "textAlign": "left", + "maxWidth": "90px", + "overflow": "hidden", + "textOverflow": "ellipsis", + "whiteSpace": "nowrap", + }), + html.Td(f"{shift:+.1f}", style={ + **table_cell_style, + "fontSize": "10px", + "color": "#666" + }), + ] + + for metric in ALL_METRICS: + val = correlations.get(metric, None) + is_sort_col = (metric == sort_metric) + + if val is None or (isinstance(val, float) and (np.isnan(val) or np.isinf(val))): + display_val = "—" + score_color = "#999" + else: + display_val = f"{val:.3f}" + if metric == 'normed_wasserstein': + if val <= 0.1: + score_color = "#28a745" + elif val <= 0.3: + score_color = "#ffc107" + else: + score_color = "#dc3545" + else: + if val >= 0.9: + score_color = "#28a745" + elif val >= 0.7: + score_color = "#ffc107" + else: + score_color = "#dc3545" + + cell_style = { + **table_cell_style, + "fontWeight": "700" if is_sort_col else "400", + "color": score_color, + "fontSize": "11px" if is_sort_col else "10px", + "backgroundColor": "#f8f8f8" if is_sort_col else "transparent", + } + + row_cells.append(html.Td(display_val, style=cell_style)) + + rows.append(html.Tr(row_cells)) + + table = html.Table( + [html.Thead(header), html.Tbody(rows)], + style={ + "width": "100%", + "borderCollapse": "collapse", + "fontFamily": base_font, + "tableLayout": "auto", + } + ) + + return html.Div(table, style={ + "overflowX": "auto", + "fontSize": "11px", + }) ctc.register_crystal_toolkit(app=app, layout=onmixas_layout) @@ -244,7 +1866,7 @@ def serve(): "please set your materials project API key to " "this environment variable before running this app") exit() - app.run(debug=False, port=8443, host='127.0.0.1') + app.run(debug=False, port=8443, host='0.0.0.0') if __name__ == "__main__": - serve() + serve() \ No newline at end of file