@@ -81,14 +81,6 @@ def pathify(path):
8181 return Path (path + ext )
8282
8383
84- def _pytest_pyfunc_call (obj , pyfuncitem ):
85- testfunction = pyfuncitem .obj
86- funcargs = pyfuncitem .funcargs
87- testargs = {arg : funcargs [arg ] for arg in pyfuncitem ._fixtureinfo .argnames }
88- obj .result = testfunction (** testargs )
89- return True
90-
91-
9284def generate_test_name (item ):
9385 """
9486 Generate a unique name for the hash for this test.
@@ -100,6 +92,24 @@ def generate_test_name(item):
10092 return name
10193
10294
95+ def wrap_figure_interceptor (plugin , item ):
96+ """
97+ Intercept and store figures returned by test functions.
98+ """
99+ # Only intercept figures on marked figure tests
100+ if get_compare (item ) is not None :
101+
102+ # Use the full test name as a key to ensure correct figure is being retrieved
103+ test_name = generate_test_name (item )
104+
105+ def figure_interceptor (store , obj ):
106+ def wrapper (* args , ** kwargs ):
107+ store .return_value [test_name ] = obj (* args , ** kwargs )
108+ return wrapper
109+
110+ item .obj = figure_interceptor (plugin , item .obj )
111+
112+
103113def pytest_report_header (config , startdir ):
104114 import matplotlib
105115 import matplotlib .ft2font
@@ -286,6 +296,7 @@ def __init__(self,
286296 self ._generated_hash_library = {}
287297 self ._test_results = {}
288298 self ._test_stats = None
299+ self .return_value = {}
289300
290301 # https://stackoverflow.com/questions/51737378/how-should-i-log-in-my-pytest-plugin
291302 # turn debug prints on only if "-vv" or more passed
@@ -608,13 +619,14 @@ def pytest_runtest_call(self, item): # noqa
608619 with plt .style .context (style , after_reset = True ), switch_backend (backend ):
609620
610621 # Run test and get figure object
622+ wrap_figure_interceptor (self , item )
611623 yield
612- fig = self .result
624+ test_name = generate_test_name (item )
625+ fig = self .return_value [test_name ]
613626
614627 if remove_text :
615628 remove_ticks_and_titles (fig )
616629
617- test_name = generate_test_name (item )
618630 result_dir = self .make_test_results_dir (item )
619631
620632 summary = {
@@ -678,10 +690,6 @@ def pytest_runtest_call(self, item): # noqa
678690 if summary ['status' ] == 'skipped' :
679691 pytest .skip (summary ['status_msg' ])
680692
681- @pytest .hookimpl (tryfirst = True )
682- def pytest_pyfunc_call (self , pyfuncitem ):
683- return _pytest_pyfunc_call (self , pyfuncitem )
684-
685693 def generate_summary_json (self ):
686694 json_file = self .results_dir / 'results.json'
687695 with open (json_file , 'w' ) as f :
@@ -733,13 +741,13 @@ class FigureCloser:
733741
734742 def __init__ (self , config ):
735743 self .config = config
744+ self .return_value = {}
736745
737746 @pytest .hookimpl (hookwrapper = True )
738747 def pytest_runtest_call (self , item ):
748+ wrap_figure_interceptor (self , item )
739749 yield
740750 if get_compare (item ) is not None :
741- close_mpl_figure (self .result )
742-
743- @pytest .hookimpl (tryfirst = True )
744- def pytest_pyfunc_call (self , pyfuncitem ):
745- return _pytest_pyfunc_call (self , pyfuncitem )
751+ test_name = generate_test_name (item )
752+ fig = self .return_value [test_name ]
753+ close_mpl_figure (fig )
0 commit comments