#!/usr/bin/env python3
# File: ./src/scitex/bridge/_plt_vis.py
# Time-stamp: "2024-12-09 10:00:00 (ywatanabe)"
"""
Bridge module for plt ↔ vis integration.
Provides adapters to:
- Convert scitex.plt figures to vis FigureModel
- Extract tracking data as PlotModel configurations
- Synchronize matplotlib state with vis JSON
"""
import warnings
from typing import Any, Dict, List, Optional, Tuple
# Legacy model imports — the models module is deprecated, so its
# DeprecationWarning is suppressed while importing.
#
# The models historically lived in `scitex_io.bundle.kinds._plot._models`,
# but the standalone scitex-io split removed that path; the umbrella
# package `scitex.io.bundle.kinds._plot._models` still ships them. Try the
# standalone path first, then the umbrella path; if neither imports, the
# model names are bound to None and VIS_MODEL_AVAILABLE is False.
def _import_vis_models():
with warnings.catch_warnings():
warnings.simplefilter("ignore", DeprecationWarning)
for mod_path in (
"scitex_io.bundle.kinds._plot._models",
"scitex.io.bundle.kinds._plot._models",
):
try:
import importlib
m = importlib.import_module(mod_path)
return (
m.AnnotationModel,
m.AxesModel,
m.AxesStyle,
m.FigureModel,
m.GuideModel,
m.PlotModel,
m.PlotStyle,
m.TextStyle,
)
except ImportError:
continue
return None
_models = _import_vis_models()
VIS_MODEL_AVAILABLE = _models is not None
if VIS_MODEL_AVAILABLE:
    # Bind the eight model classes in the order _import_vis_models returns.
    (
        AnnotationModel,
        AxesModel,
        AxesStyle,
        FigureModel,
        GuideModel,
        PlotModel,
        PlotStyle,
        TextStyle,
    ) = _models
else:
    # Placeholders keep the names defined, so the return annotations on the
    # functions below (evaluated at definition time) still resolve.
    AnnotationModel = AxesModel = AxesStyle = FigureModel = None
    GuideModel = PlotModel = PlotStyle = TextStyle = None
[docs]
def axes_to_vis_axes(
    ax,
    row: int = 0,
    col: int = 0,
    scitex_ax=None,
    include_data: bool = True,
    include_style: bool = True,
) -> AxesModel:
    """
    Convert a matplotlib axes to a vis AxesModel.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The axes to convert. A wrapper exposing ``_axes_mpl`` is unwrapped
        to the underlying matplotlib axes.
    row : int
        Row position in layout
    col : int
        Column position in layout
    scitex_ax : AxisWrapper, optional
        Scitex axis wrapper with tracking history
    include_data : bool
        Whether to include plot data (taken from ``scitex_ax.history``)
    include_style : bool
        Whether to include style information

    Returns
    -------
    AxesModel
        The vis axes model, with labels, limits, scales, ticks, optional
        style and plots, text annotations and detected guide lines.

    Raises
    ------
    ImportError
        If the vis model classes could not be imported
        (``VIS_MODEL_AVAILABLE`` is False).
    """
    # Without this check the first model constructor call fails with an
    # opaque "'NoneType' object is not callable".
    if not VIS_MODEL_AVAILABLE:
        raise ImportError(
            "axes_to_vis_axes requires the vis model classes, which could "
            "not be imported (VIS_MODEL_AVAILABLE is False)"
        )

    # Get underlying matplotlib axes
    mpl_ax = ax._axes_mpl if hasattr(ax, "_axes_mpl") else ax

    # Extract axis properties; empty label/title strings become None.
    axes_model = AxesModel(
        row=row,
        col=col,
        xlabel=mpl_ax.get_xlabel() or None,
        ylabel=mpl_ax.get_ylabel() or None,
        title=mpl_ax.get_title() or None,
        xlim=list(mpl_ax.get_xlim()),
        ylim=list(mpl_ax.get_ylim()),
        xscale=mpl_ax.get_xscale(),
        yscale=mpl_ax.get_yscale(),
    )

    # Extract tick positions as plain floats (left unset when there are none)
    xticks = mpl_ax.get_xticks()
    yticks = mpl_ax.get_yticks()
    if len(xticks) > 0:
        axes_model.xticks = [float(t) for t in xticks]
    if len(yticks) > 0:
        axes_model.yticks = [float(t) for t in yticks]

    # Extract style if requested
    if include_style:
        axes_model.style = _extract_axes_style(mpl_ax)

    # Extract plots from tracking history
    if include_data and scitex_ax is not None and hasattr(scitex_ax, "history"):
        for plot in tracking_to_plot_configs(scitex_ax.history):
            axes_model.plots.append(
                plot.to_dict() if hasattr(plot, "to_dict") else plot
            )

    # Extract annotations (blank texts are skipped by _text_to_annotation)
    for text_obj in mpl_ax.texts:
        annotation = _text_to_annotation(text_obj)
        if annotation:
            axes_model.annotations.append(annotation.to_dict())

    # Extract guides (axhline, axvline)
    for guide in _extract_guides(mpl_ax):
        axes_model.guides.append(guide.to_dict())

    return axes_model
[docs]
def tracking_to_plot_configs(
    history: Dict[str, Tuple],
) -> List[PlotModel]:
    """
    Turn a scitex.plt tracking history into PlotModel configurations.

    Parameters
    ----------
    history : Dict[str, Tuple]
        Tracking history from an AxisWrapper, mapping each plot id to a
        4-tuple ``(id, method_name, tracked_dict, kwargs)``. The id stored
        inside the tuple is ignored; the dict key is used as the plot id.

    Returns
    -------
    List[PlotModel]
        One model per history entry, in the history's iteration order.
        Entries the converter maps to a falsy result are dropped.
    """
    converted = (
        _history_entry_to_plot_model(plot_id, method_name, tracked_dict, kwargs)
        for plot_id, (_, method_name, tracked_dict, kwargs) in history.items()
    )
    return [model for model in converted if model]
# =============================================================================
# Helper Functions
# =============================================================================
def _get_mpl_figure(fig):
"""Get the underlying matplotlib figure."""
if hasattr(fig, "_fig_mpl"):
return fig._fig_mpl
return fig
def _get_scitex_axes(fig):
"""Get scitex axes wrappers from figure."""
if hasattr(fig, "_axes_scitex"):
axes = fig._axes_scitex
if hasattr(axes, "flat"):
return list(axes.flat)
return [axes]
return []
def _find_scitex_axis(scitex_axes, mpl_ax):
"""Find the scitex axis wrapper that wraps the given mpl axis."""
for ax in scitex_axes:
if hasattr(ax, "_axes_mpl") and ax._axes_mpl is mpl_ax:
return ax
return None
def _infer_layout(axes_list, fig) -> Tuple[int, int]:
"""Infer nrows, ncols from axes positions."""
if not axes_list:
return 1, 1
# Check if using gridspec
if hasattr(fig, "_gridspecs") and fig._gridspecs:
gs = fig._gridspecs[0]
return gs.nrows, gs.ncols
# Fallback: guess from axes count
n = len(axes_list)
if n == 1:
return 1, 1
elif n == 2:
return 1, 2
elif n <= 4:
return 2, 2
else:
# Try to make it roughly square
import math
ncols = int(math.ceil(math.sqrt(n)))
nrows = int(math.ceil(n / ncols))
return nrows, ncols
def _color_to_hex(color) -> str:
"""Convert matplotlib color to hex string."""
try:
import matplotlib.colors as mcolors
rgb = mcolors.to_rgb(color)
return f"#{int(rgb[0] * 255):02x}{int(rgb[1] * 255):02x}{int(rgb[2] * 255):02x}"
except (ValueError, TypeError):
return "#ffffff"
def _gridlines_visible(axis) -> bool:
    """Return whether the first gridline of a matplotlib Axis is visible.

    Best-effort: returns False when ``axis`` has no gridlines or lacks the
    ``get_gridlines`` API (including ``axis`` being None).
    """
    try:
        gridlines = axis.get_gridlines()
        return bool(gridlines) and bool(gridlines[0].get_visible())
    except (AttributeError, IndexError):
        return False


def _extract_axes_style(mpl_ax) -> AxesStyle:
    """Extract style information from matplotlib axes.

    Returns
    -------
    AxesStyle
        ``facecolor`` as ``#rrggbb``. ``grid`` is True when the gridlines of
        either the x or the y axis are visible (so ``ax.grid(axis="y")`` is
        not lost). ``spines_visible`` maps top/right/bottom/left to their
        visibility; a side missing from ``mpl_ax.spines`` is reported as
        False instead of raising KeyError.
    """
    grid_visible = any(
        _gridlines_visible(getattr(mpl_ax, axis_name, None))
        for axis_name in ("xaxis", "yaxis")
    )
    spines = mpl_ax.spines
    return AxesStyle(
        facecolor=_color_to_hex(mpl_ax.get_facecolor()),
        grid=grid_visible,
        spines_visible={
            side: side in spines and spines[side].get_visible()
            for side in ("top", "right", "bottom", "left")
        },
    )
def _text_to_annotation(text_obj) -> Optional[AnnotationModel]:
    """Build a ``"text"`` AnnotationModel from a matplotlib text artist.

    Returns None when the text is empty or whitespace-only. The position is
    copied as reported by ``get_position()``, and font size, color (as hex),
    alignment and rotation go into a TextStyle.
    """
    content = text_obj.get_text()
    if not (content and content.strip()):
        return None

    x, y = text_obj.get_position()
    text_style = TextStyle(
        fontsize=text_obj.get_fontsize(),
        color=_color_to_hex(text_obj.get_color()),
        ha=text_obj.get_ha(),
        va=text_obj.get_va(),
        rotation=text_obj.get_rotation(),
    )
    return AnnotationModel(
        annotation_type="text",
        text=content,
        x=x,
        y=y,
        style=text_style,
    )
def _spans_full_range(start, end, limits, tol: float = 0.01) -> bool:
    """Return True if ``start``/``end`` lie within ``tol`` of ``limits[0]``/``limits[1]``."""
    return abs(start - limits[0]) < tol and abs(end - limits[1]) < tol


def _extract_guides(mpl_ax) -> List[GuideModel]:
    """Extract full-span guide lines (axhline, axvline) from axes.

    A line in ``mpl_ax.lines`` is reported as a guide when every point
    shares one y value (horizontal) or one x value (vertical), and its
    first/last points span the whole axis. The span is checked against:

    * ``(0, 1)`` when the line uses the axes' blended "grid" transform. This
      is the transform ``Axes.axhline`` / ``Axes.axvline`` attach, and the
      line's free coordinate is then in axes-fraction units, not data units.
    * The current data limits otherwise (e.g. a line drawn by hand from
      ``xlim[0]`` to ``xlim[1]``).

    Both checks use an absolute tolerance of 0.01 in their own units.

    Returns
    -------
    List[GuideModel]
        One ``"axhline"`` / ``"axvline"`` guide per detected line, in
        ``mpl_ax.lines`` order, carrying the line's color, style and width.
    """
    guides = []
    hline_transform = mpl_ax.get_yaxis_transform(which="grid")
    vline_transform = mpl_ax.get_xaxis_transform(which="grid")
    axes_fraction = (0.0, 1.0)

    for line in mpl_ax.lines:
        data = line.get_xydata()
        if len(data) < 2:
            continue
        x_first, y_first = data[0][0], data[0][1]
        x_last, y_last = data[-1][0], data[-1][1]
        transform = line.get_transform()

        if y_first == y_last and x_first != x_last:
            # Equal endpoints are not enough: a curve that starts and ends
            # at the same height is not a horizontal line.
            if not all(point[1] == y_first for point in data):
                continue
            limits = (
                axes_fraction
                if transform is hline_transform
                else mpl_ax.get_xlim()
            )
            if _spans_full_range(x_first, x_last, limits):
                guides.append(
                    GuideModel(
                        guide_type="axhline",
                        y=y_first,
                        color=_color_to_hex(line.get_color()),
                        linestyle=line.get_linestyle(),
                        linewidth=line.get_linewidth(),
                    )
                )
        elif x_first == x_last and y_first != y_last:
            if not all(point[0] == x_first for point in data):
                continue
            limits = (
                axes_fraction
                if transform is vline_transform
                else mpl_ax.get_ylim()
            )
            if _spans_full_range(y_first, y_last, limits):
                guides.append(
                    GuideModel(
                        guide_type="axvline",
                        x=x_first,
                        color=_color_to_hex(line.get_color()),
                        linestyle=line.get_linestyle(),
                        linewidth=line.get_linewidth(),
                    )
                )
    return guides
# Matplotlib method name -> vis plot type. Methods not listed are passed
# through unchanged as their own plot type.
_METHOD_TO_PLOT_TYPE = {
    "plot": "line",
    "scatter": "scatter",
    "bar": "bar",
    "barh": "barh",
    "hist": "histogram",
    "boxplot": "boxplot",
    "violinplot": "violin",
    "fill_between": "fill_between",
    "errorbar": "errorbar",
    "imshow": "imshow",
    "contour": "contour",
    "contourf": "contourf",
}

# PlotStyle attribute -> matplotlib kwarg names that set it, in priority order.
_STYLE_KWARG_ALIASES = (
    ("linewidth", ("linewidth", "lw")),
    ("linestyle", ("linestyle", "ls")),
    ("marker", ("marker",)),
    ("alpha", ("alpha",)),
    ("label", ("label",)),
)


def _history_entry_to_plot_model(
    plot_id: str,
    method_name: str,
    tracked_dict: Dict,
    kwargs: Dict,
) -> Optional[PlotModel]:
    """Convert one tracking-history entry to a PlotModel.

    Parameters
    ----------
    plot_id : str
        Identifier stored on the resulting model.
    method_name : str
        Name of the tracked plotting method. Mapped through
        ``_METHOD_TO_PLOT_TYPE``; unmapped names are used verbatim.
    tracked_dict : dict or None
        Tracked call data. Only ``tracked_dict["args"]`` is read. It is
        indexed as the call's positional arguments: x/y for plot and scatter,
        x/height for bar, x for hist. Other methods record no data.
    kwargs : dict or None
        Call keyword arguments. ``color``, ``linewidth``/``lw``,
        ``linestyle``/``ls``, ``marker``, ``alpha`` and ``label`` are copied
        into the PlotStyle when present. None is treated as no kwargs.

    Returns
    -------
    PlotModel
        Always a model (never None in the current implementation).
    """
    if kwargs is None:
        kwargs = {}

    # Data: capture the leading positional args for a few known methods.
    data = {}
    if tracked_dict is not None and "args" in tracked_dict:
        args = tracked_dict["args"]
        if method_name in ("plot", "scatter") and len(args) >= 2:
            data["x"] = _array_to_list(args[0])
            data["y"] = _array_to_list(args[1])
        elif method_name == "bar" and len(args) >= 2:
            data["x"] = _array_to_list(args[0])
            data["height"] = _array_to_list(args[1])
        elif method_name == "hist" and len(args) >= 1:
            data["x"] = _array_to_list(args[0])

    style = PlotStyle()
    if "color" in kwargs:
        color = kwargs["color"]
        # Test emptiness explicitly: a plain `if color` raises ValueError for
        # multi-element NumPy arrays (e.g. an RGB triple).
        is_unset = color is None or (hasattr(color, "__len__") and len(color) == 0)
        style.color = None if is_unset else _color_to_hex(color)
    for attr, aliases in _STYLE_KWARG_ALIASES:
        # Pick the first alias present *by key*, so falsy-but-valid values
        # such as linewidth=0 or linestyle="" are kept rather than skipped.
        present = [key for key in aliases if key in kwargs]
        if present:
            setattr(style, attr, kwargs[present[0]])

    return PlotModel(
        plot_type=_METHOD_TO_PLOT_TYPE.get(method_name, method_name),
        plot_id=plot_id,
        data=data,
        style=style,
    )
def _array_to_list(arr) -> List:
"""Convert array-like to list for serialization."""
if hasattr(arr, "tolist"):
return arr.tolist()
elif isinstance(arr, (list, tuple)):
return list(arr)
return [arr]
def _is_array_like(obj) -> bool:
"""Check if object is array-like."""
return hasattr(obj, "__len__") and not isinstance(obj, (str, dict))
def _is_serializable(obj) -> bool:
"""Check if object is JSON serializable."""
import json
try:
json.dumps(obj)
return True
except (TypeError, ValueError):
return False
# Public API of this module (names exported by `from ... import *`).
# NOTE(review): "figure_to_vis_model" and "collect_figure_data" are not
# defined in the code visible here — confirm they are defined elsewhere in
# this module; otherwise a star-import raises AttributeError.
__all__ = [
    "figure_to_vis_model",
    "axes_to_vis_axes",
    "tracking_to_plot_configs",
    "collect_figure_data",
]
# EOF