Skip to content

Commit ca5c9f4

Browse files
committed
Merge pull request #43 from haoxingz/magic_test
Magic test
2 parents c230b79 + 2e3bfd9 commit ca5c9f4

File tree

2 files changed

+140
-28
lines changed

2 files changed

+140
-28
lines changed

pymatbridge/matlab_magic.py

Lines changed: 57 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
has_io = False
2929
no_io_str = "Must have h5py and scipy.io to perform i/o"
3030
no_io_str += "operations with the Matlab session"
31-
31+
3232
from IPython.core.displaypub import publish_display_data
3333
from IPython.core.magic import (Magics, magics_class, cell_magic, line_magic,
3434
line_cell_magic, needs_local_scope)
@@ -39,7 +39,7 @@
3939

4040
import pymatbridge as pymat
4141

42-
42+
4343
class MatlabInterperterError(RuntimeError):
4444
"""
4545
Some error occurs while matlab is running
@@ -52,7 +52,7 @@ def __unicode__(self):
5252
s = "Failed to parse and evaluate line %r.\n Matlab error message: %r"%\
5353
(self.line, self.err)
5454
return s
55-
55+
5656
if PY3:
5757
__str__ = __unicode__
5858
else:
@@ -66,25 +66,54 @@ def loadmat(fname):
6666
"""
6767

6868
f = h5py.File(fname)
69-
data = f.values()[0][:]
70-
if len(data.dtype) > 0:
71-
# must be complex data
72-
data = data['real'] + 1j * data['imag']
73-
return data
69+
70+
for var_name in f.iterkeys():
71+
if isinstance(f[var_name], h5py.Dataset):
72+
# Currently only supports numerical array
73+
data = f[var_name].value
74+
if len(data.dtype) > 0:
75+
# must be complex data
76+
data = data['real'] + 1j * data['imag']
77+
return np.squeeze(data.T)
78+
79+
elif isinstance(f[var_name], h5py.Group):
80+
data = {}
81+
for mem_name in f[var_name].iterkeys():
82+
if isinstance(f[var_name][mem_name], h5py.Dataset):
83+
# Check if the dataset is a string
84+
attr = h5py.AttributeManager(f[var_name][mem_name])
85+
if (attr.__getitem__('MATLAB_class') == 'char'):
86+
is_string = True
87+
else:
88+
is_string = False
89+
90+
data[mem_name] = f[var_name][mem_name].value
91+
data[mem_name] = np.squeeze(data[mem_name].T)
92+
93+
if is_string:
94+
result = ''
95+
for asc in data[mem_name]:
96+
result += chr(asc)
97+
data[mem_name] = result
98+
else:
99+
# Currently doesn't support nested struct
100+
pass
101+
102+
return data
74103

75104

76105
def matlab_converter(matlab, key):
77106
"""
78107
79108
Reach into the matlab namespace and get me the value of the variable
80-
109+
81110
"""
82111
tempdir = tempfile.gettempdir()
83112
# We save as hdf5 in the matlab session, so that we can grab large
84113
# variables:
85114
matlab.run_code("save('%s/%s.mat','%s','-v7.3')"%(tempdir, key, key),
86115
maxtime=matlab.maxtime)
87-
116+
88117
return loadmat('%s/%s.mat'%(tempdir, key))
89118

90119

@@ -113,17 +142,17 @@ def __init__(self, shell,
113142
maxtime : float
114143
The maximal time to wait for responses for matlab (in seconds).
115144
Default: 10 seconds.
116-
145+
117146
pyconverter : callable
118147
To be called on matlab variables returning into the ipython
119148
namespace
120-
149+
121150
matlab_converter : callable
122-
To be called on values in ipython namespace before
151+
To be called on values in ipython namespace before
123152
assigning to variables in matlab.
124153
125154
cache_display_data : bool
126-
If True, the published results of the final call to R are
155+
If True, the published results of the final call to R are
127156
cached in the variable 'display_cache'.
128157
129158
"""
@@ -133,7 +162,7 @@ def __init__(self, shell,
133162
self.Matlab = pymat.Matlab(matlab, maxtime=maxtime)
134163
self.Matlab.start()
135164
self.pyconverter = pyconverter
136-
self.matlab_converter = matlab_converter
165+
self.matlab_converter = matlab_converter
137166

138167
def __del__(self):
139168
"""shut down the Matlab server when the object dies.
@@ -154,9 +183,9 @@ def eval(self, line):
154183
if run_dict['success'] == 'false':
155184
raise MatlabInterperterError(line, run_dict['content']['stdout'])
156185

157-
# This is the matlab stdout:
186+
# This is the matlab stdout:
158187
return run_dict
159-
188+
160189
@magic_arguments()
161190
@argument(
162191
'-i', '--input', action='append',
@@ -180,7 +209,7 @@ def matlab(self, line, cell=None, local_ns=None):
180209
"""
181210
182211
Execute code in matlab
183-
212+
184213
"""
185214
args = parse_argstring(self.matlab, line)
186215

@@ -210,7 +239,7 @@ def matlab(self, line, cell=None, local_ns=None):
210239
except KeyError:
211240
val = self.shell.user_ns[input]
212241
# We save these input arguments into a .mat file:
213-
tempdir = tempfile.gettempdir()
242+
tempdir = tempfile.gettempdir()
214243
sio.savemat('%s/%s.mat'%(tempdir, input),
215244
eval("dict(%s=val)"%input), oned_as='row')
216245

@@ -219,7 +248,7 @@ def matlab(self, line, cell=None, local_ns=None):
219248

220249
else:
221250
raise RuntimeError(no_io_str)
222-
251+
223252
text_output = ''
224253
#imgfiles = []
225254

@@ -234,14 +263,14 @@ def matlab(self, line, cell=None, local_ns=None):
234263
e_s += "\n-----------------------"
235264
e_s += "\nAre you sure Matlab is started?"
236265
raise RuntimeError(e_s)
237-
238-
266+
267+
239268

240269
text_output += result_dict['content']['stdout']
241270
# Figures get saved by matlab in reverse order...
242271
imgfiles = result_dict['content']['figures'][::-1]
243272
data_dir = result_dict['content']['datadir']
244-
273+
245274
display_data = []
246275
if text_output:
247276
display_data.append(('MatlabMagic.matlab',
@@ -251,7 +280,7 @@ def matlab(self, line, cell=None, local_ns=None):
251280
if len(imgf):
252281
# Store the path to the directory so that you can delete it
253282
# later on:
254-
image = open(imgf, 'rb').read()
283+
image = open(imgf, 'rb').read()
255284
display_data.append(('MatlabMagic.matlab',
256285
{'image/png': image}))
257286

@@ -261,24 +290,24 @@ def matlab(self, line, cell=None, local_ns=None):
261290
# Delete the temporary data files created by matlab:
262291
if len(data_dir):
263292
rmtree(data_dir)
264-
293+
265294
if args.output:
266295
if has_io:
267296
for output in ','.join(args.output).split(','):
268297
self.shell.push({output:self.matlab_converter(self.Matlab,
269298
output)})
270299
else:
271300
raise RuntimeError(no_io_str)
272-
273-
301+
302+
274303
_loaded = False
275304
def load_ipython_extension(ip, **kwargs):
276305
"""Load the extension in IPython."""
277306
global _loaded
278307
if not _loaded:
279308
ip.register_magics(MatlabMagics(ip, **kwargs))
280309
_loaded = True
281-
310+
282311
def unload_ipython_extension(ip):
283312
global _loaded
284313
if _loaded:

pymatbridge/tests/test_magic.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
import pymatbridge as pymat
2+
import IPython
3+
4+
import numpy.testing as npt
5+
6+
class TestMagic:
7+
8+
# Create an IPython shell and load Matlab magic
9+
@classmethod
10+
def setup_class(cls):
11+
cls.ip = IPython.InteractiveShell()
12+
cls.ip.run_cell('import random')
13+
cls.ip.run_cell('import numpy as np')
14+
pymat.load_ipython_extension(cls.ip)
15+
16+
# Unload the magic, shut down Matlab
17+
@classmethod
18+
def teardown_class(cls):
19+
pymat.unload_ipython_extension(cls.ip)
20+
21+
22+
# Test single operation on different data structures
23+
def test_cell_magic_number(self):
24+
# A double precision real number
25+
self.ip.run_cell("a = np.float64(random.random())")
26+
self.ip.run_cell_magic('matlab', '-i a -o b', 'b = a*2;')
27+
npt.assert_almost_equal(self.ip.user_ns['b'],
28+
self.ip.user_ns['a']*2, decimal=7)
29+
30+
# A complex number
31+
self.ip.run_cell("x = 3.34+4.56j")
32+
self.ip.run_cell_magic('matlab', '-i x -o y', 'y = x*(11.35 - 23.098j)')
33+
self.ip.run_cell("res = x*(11.35 - 23.098j)")
34+
npt.assert_almost_equal(self.ip.user_ns['y'],
35+
self.ip.user_ns['res'], decimal=7)
36+
37+
38+
def test_cell_magic_array(self):
39+
# Random array multiplication
40+
self.ip.run_cell("val1 = np.random.random_sample((3,3))")
41+
self.ip.run_cell("val2 = np.random.random_sample((3,3))")
42+
self.ip.run_cell("respy = np.dot(val1, val2)")
43+
self.ip.run_cell_magic('matlab', '-i val1,val2 -o resmat',
44+
'resmat = val1 * val2')
45+
npt.assert_almost_equal(self.ip.user_ns['resmat'],
46+
self.ip.user_ns['respy'], decimal=7)
47+
48+
49+
def test_line_magic(self):
50+
# Some operation in Matlab
51+
self.ip.run_line_magic('matlab', 'a = [1 2 3]')
52+
self.ip.run_line_magic('matlab', 'res = a*2')
53+
# Get the result back to Python
54+
self.ip.run_cell_magic('matlab', '-o actual', 'actual = res')
55+
56+
self.ip.run_cell("expected = np.array([2, 4, 6])")
57+
npt.assert_almost_equal(self.ip.user_ns['actual'],
58+
self.ip.user_ns['expected'], decimal=7)
59+
60+
def test_figure(self):
61+
# Just make a plot to get more testing coverage
62+
self.ip.run_line_magic('matlab', 'plot([1 2 3])')
63+
64+
65+
def test_matrix(self):
66+
self.ip.run_cell("in_array = np.array([[1,2,3], [4,5,6]])")
67+
self.ip.run_cell_magic('matlab', '-i in_array -o out_array',
68+
'out_array = in_array;')
69+
npt.assert_almost_equal(self.ip.user_ns['out_array'],
70+
self.ip.user_ns['in_array'],
71+
decimal=7)
72+
73+
# Matlab struct type should be converted to a Python dict
74+
def test_struct(self):
75+
self.ip.run_cell('num = 2.567')
76+
self.ip.run_cell('num_array = np.array([1.2,3.4,5.6])')
77+
self.ip.run_cell('str = "Hello World"')
78+
self.ip.run_cell_magic('matlab', '-i num,num_array,str -o obj',
79+
'obj.num = num;obj.num_array = num_array;obj.str = str;')
80+
npt.assert_equal(isinstance(self.ip.user_ns['obj'], dict), True)
81+
npt.assert_equal(self.ip.user_ns['obj']['num'], self.ip.user_ns['num'])
82+
npt.assert_equal(self.ip.user_ns['obj']['num_array'], self.ip.user_ns['num_array'])
83+
npt.assert_equal(self.ip.user_ns['obj']['str'], self.ip.user_ns['str'])

0 commit comments

Comments
 (0)