@@ -156,7 +156,7 @@ def ind_arr(shape, columns=False):
156156 q = create_full (F_shape , 0.0 , dtype )
157157
158158 # bathymetry
159- h = create_full (T_shape , 1 .0 , dtype ) # HACK init with 1
159+ h = create_full (T_shape , 0 .0 , dtype )
160160
161161 hu = create_full (U_shape , 0.0 , dtype )
162162 hv = create_full (V_shape , 0.0 , dtype )
@@ -165,7 +165,7 @@ def ind_arr(shape, columns=False):
165165 dvdx = create_full (F_shape , 0.0 , dtype )
166166
167167 # vector invariant form
168- H_at_f = create_full (F_shape , 0 .0 , dtype )
168+ H_at_f = create_full (F_shape , 1 .0 , dtype ) # HACK init with 1
169169
170170 # auxiliary variables for RK time integration
171171 e1 = create_full (T_shape , 0.0 , dtype )
@@ -205,15 +205,14 @@ def bathymetry(x_t_2d, y_t_2d, lx, ly):
205205 return bath * create_full (T_shape , 1.0 , dtype )
206206
207207 # set bathymetry
208- # h[:, :] = bathymetry(x_t_2d, y_t_2d, lx, ly).to_device(device)
208+ h [:, :] = bathymetry (x_t_2d , y_t_2d , lx , ly ).to_device (device )
209209 # steady state potential energy
210- # pe_offset = 0.5 * g * float( np.sum(h**2.0, all_axes)) / nx / ny
211- pe_offset = 0.5 * g * float (1.0 ) / nx / ny
210+ h2sum = np .sum (h ** 2.0 , all_axes ). to_device ()
211+ pe_offset = 0.5 * g * float (np . sum ( h2sum , all_axes ) ) / nx / ny
212212
213213 # compute time step
214214 alpha = 0.5
215- # h_max = float(np.max(h, all_axes))
216- h_max = float (1.0 )
215+ h_max = float (np .max (h , all_axes ).to_device ())
217216 c = (g * h_max ) ** 0.5
218217 dt = alpha * dx / c
219218 dt = t_export / int (math .ceil (t_export / dt ))
@@ -253,10 +252,11 @@ def rhs(u, v, e):
253252 H_at_f [- 1 , 1 :- 1 ] = 0.5 * (H [- 1 , 1 :] + H [- 1 , :- 1 ])
254253 H_at_f [1 :- 1 , 0 ] = 0.5 * (H [1 :, 0 ] + H [:- 1 , 0 ])
255254 H_at_f [1 :- 1 , - 1 ] = 0.5 * (H [1 :, - 1 ] + H [:- 1 , - 1 ])
256- H_at_f [0 , 0 ] = H [0 , 0 ]
257- H_at_f [0 , - 1 ] = H [0 , - 1 ]
258- H_at_f [- 1 , 0 ] = H [- 1 , 0 ]
259- H_at_f [- 1 , - 1 ] = H [- 1 , - 1 ]
255+ # NOTE causes gpu.memcpy error, non-identity layout
256+ # H_at_f[0, 0] = H[0, 0]
257+ # H_at_f[0, -1] = H[0, -1]
258+ # H_at_f[-1, 0] = H[-1, 0]
259+ # H_at_f[-1, -1] = H[-1, -1]
260260
261261 # potential vorticity
262262 dudy [:, 1 :- 1 ] = (u [:, 1 :] - u [:, :- 1 ]) / dy
@@ -346,41 +346,36 @@ def step(u, v, e, u1, v1, e1, u2, v2, e2):
346346 t = i * dt
347347
348348 if t >= next_t_export - 1e-8 :
349- if device :
350- # FIXME gpu.memcpy to host requires identity layout
351- # FIXME reduction on gpu
352- elev_max = 0
353- u_max = 0
354- q_max = 0
355- diff_e = 0
356- diff_v = 0
357- total_pe = 0
358- total_ke = 0
359- else :
360- _elev_max = np .max (e , all_axes )
361- _u_max = np .max (u , all_axes )
362- _q_max = np .max (q , all_axes )
363- _total_v = np .sum (e + h , all_axes )
364-
365- # potential energy
366- _pe = 0.5 * g * (e + h ) * (e - h ) + pe_offset
367- _total_pe = np .sum (_pe , all_axes )
368-
369- # kinetic energy
370- u2 = u * u
371- v2 = v * v
372- u2_at_t = 0.5 * (u2 [1 :, :] + u2 [:- 1 , :])
373- v2_at_t = 0.5 * (v2 [:, 1 :] + v2 [:, :- 1 ])
374- _ke = 0.5 * (u2_at_t + v2_at_t ) * (e + h )
375- _total_ke = np .sum (_ke , all_axes )
376-
377- total_pe = float (_total_pe ) * dx * dy
378- total_ke = float (_total_ke ) * dx * dy
379- total_e = total_ke + total_pe
380- elev_max = float (_elev_max )
381- u_max = float (_u_max )
382- q_max = float (_q_max )
383- total_v = float (_total_v ) * dx * dy
349+ sync ()
350+ # NOTE must precompute reduction operands to single field
351+ H_tmp = e + h
352+ # potential energy
353+ _pe = 0.5 * g * (e + h ) * (e - h ) + pe_offset
354+ # kinetic energy
355+ u2 = u * u
356+ v2 = v * v
357+ u2_at_t = 0.5 * (u2 [1 :, :] + u2 [:- 1 , :])
358+ v2_at_t = 0.5 * (v2 [:, 1 :] + v2 [:, :- 1 ])
359+ _ke = 0.5 * (u2_at_t + v2_at_t ) * (e + h )
360+ sync ()
361+ _elev_max = np .max (e , all_axes ).to_device ()
362+ # NOTE max(u) segfaults, shape (n+1, n) too large for tiling
363+ _u_max = np .max (u [1 :, :], all_axes ).to_device ()
364+ _q_max = np .max (q [1 :, 1 :], all_axes ).to_device ()
365+ _total_v = np .sum (H_tmp , all_axes ).to_device ()
366+ _total_pe = np .sum (_pe , all_axes ).to_device ()
367+ _total_ke = np .sum (_ke , all_axes ).to_device ()
368+
369+ total_pe = float (_total_pe ) * dx * dy
370+ total_ke = float (_total_ke ) * dx * dy
371+ total_e = total_ke + total_pe
372+ elev_max = float (_elev_max )
373+ u_max = float (_u_max )
374+ q_max = float (_q_max )
375+ total_v = float (_total_v ) * dx * dy
376+
377+ diff_e = 0
378+ diff_v = 0
384379
385380 if i_export == 0 :
386381 initial_v = total_v
@@ -415,40 +410,36 @@ def step(u, v, e, u1, v1, e1, u2, v2, e2):
415410 duration = time_mod .perf_counter () - tic
416411 info (f"Duration: { duration :.2f} s" )
417412
418- if device :
419- # FIXME gpu.memcpy to host requires identity layout
420- # FIXME reduction on gpu
421- pass
422- else :
423- e_exact = exact_solution (
424- t , x_t_2d , y_t_2d , x_u_2d , y_u_2d , x_v_2d , y_v_2d
425- )[2 ]
426- err2 = (e_exact - e ) * (e_exact - e ) * dx * dy / lx / ly
427- err_L2 = math .sqrt (float (np .sum (err2 , all_axes )))
428- info (f"L2 error: { err_L2 :7.15e} " )
429-
430- if nx < 128 or ny < 128 :
431- info ("Skipping correctness test due to small problem size." )
432- elif not benchmark_mode :
433- tolerance_ene = 1e-7 if datatype == "f32" else 1e-9
434- assert (
435- diff_e < tolerance_ene
436- ), f"Energy error exceeds tolerance: { diff_e } > { tolerance_ene } "
437- if nx == 128 and ny == 128 :
438- if datatype == "f32" :
439- assert numpy .allclose (
440- err_L2 , 4.3127859e-05 , rtol = 1e-5
441- ), "L2 error does not match"
442- else :
443- assert numpy .allclose (
444- err_L2 , 4.315799035627906e-05
445- ), "L2 error does not match"
413+ e_exact = exact_solution (t , x_t_2d , y_t_2d , x_u_2d , y_u_2d , x_v_2d , y_v_2d )[
414+ 2
415+ ].to_device (device )
416+ err2 = (e_exact - e ) * (e_exact - e ) * dx * dy / lx / ly
417+ err2sum = np .sum (err2 , all_axes ).to_device ()
418+ err_L2 = math .sqrt (float (err2sum ))
419+ info (f"L2 error: { err_L2 :7.15e} " )
420+
421+ if nx < 128 or ny < 128 :
422+ info ("Skipping correctness test due to small problem size." )
423+ elif not benchmark_mode :
424+ tolerance_ene = 1e-7 if datatype == "f32" else 1e-9
425+ assert (
426+ diff_e < tolerance_ene
427+ ), f"Energy error exceeds tolerance: { diff_e } > { tolerance_ene } "
428+ if nx == 128 and ny == 128 :
429+ if datatype == "f32" :
430+ assert numpy .allclose (
431+ err_L2 , 4.3127859e-05 , rtol = 1e-5
432+ ), "L2 error does not match"
446433 else :
447- tolerance_l2 = 1e-4
448- assert (
449- err_L2 < tolerance_l2
450- ), f"L2 error exceeds tolerance: { err_L2 } > { tolerance_l2 } "
451- info ("SUCCESS" )
434+ assert numpy .allclose (
435+ err_L2 , 4.315799035627906e-05
436+ ), "L2 error does not match"
437+ else :
438+ tolerance_l2 = 1e-4
439+ assert (
440+ err_L2 < tolerance_l2
441+ ), f"L2 error exceeds tolerance: { err_L2 } > { tolerance_l2 } "
442+ info ("SUCCESS" )
452443
453444 fini ()
454445
0 commit comments