# Source code for brian2wasm.device

"""
Module implementing the WASM/JS "standalone" device.
"""
import os
import platform
import re
import shutil
import tempfile
import time
from collections import Counter

import numpy as np

from brian2.units import second
from brian2.core.namespace import get_local_namespace
from brian2.core.preferences import prefs, BrianPreference
from brian2.synapses import Synapses
from brian2.utils.logger import get_logger
from brian2.utils.filetools import in_directory
from brian2.devices import all_devices
from brian2.devices.cpp_standalone.device import CPPStandaloneDevice, CPPWriter
from brian2.utils.filetools import ensure_directory


# Module-level logger (Brian's logging wrapper from brian2.utils.logger),
# named after this module.
logger = get_logger(__name__)

# Register the ``devices.wasm_standalone.*`` preferences used by this device:
# the emsdk location/version (used when generating the makefile and, for the
# location, when launching ``emrun``) and extra flags that are appended to the
# ``emcc`` compile and link commands when building.
prefs.register_preferences(
    'devices.wasm_standalone',
    'Preferences for the WebAsm backend',
    emsdk_directory=BrianPreference(
        default="",
        docs="""
            Absolute path to the *emsdk* installation. Leave empty to use the
            EMSDK/CONDA_EMSDK_DIR environment variables or an already-activated
            emsdk in your shell.
            """,
    ),
    emsdk_version=BrianPreference(
        default="latest",
        docs="""
            Version string passed to ``emsdk activate`` (e.g. ``"3.1.56"``).
            Ignored when *emsdk_directory* is empty and the SDK is pre-activated.
            """,
    ),
    emcc_compile_args=BrianPreference(
        default=["-w"],
        docs="""
            Extra flags appended to every *emcc* **compile** command.
            Example: ``["-O3", "-sASSERTIONS"]``.
            """,
    ),
    emcc_link_args=BrianPreference(
        default=[],
        docs="""
            Extra flags appended to the final *emcc* **link** command that produces
            ``wasm_module.js`` / ``.wasm``.
            Example: ``["-sEXPORT_ES6", "-sEXPORTED_RUNTIME_METHODS=['cwrap']"]``.
            """,
    ),
)


# Fallback values for the placeholders of the generated ``index.html`` page.
# These are also the only keys accepted in a user-supplied ``html_content``
# dict; keys the user omits are taken from here.
DEFAULT_HTML_CONTENT = dict(
    title='Brian simulation',
    h1='',
    h2='',
    description='',
    canvas_width='95%',
    canvas_height='500px',
)

class WASMStandaloneDevice(CPPStandaloneDevice):
    """
    The `Device` used for WASM simulations.

    Extends ``CPPStandaloneDevice``. It generates the C++ code with the
    ``brian2wasm`` templates and compiles it with Emscripten (``emcc``). It
    also adds the JavaScript runtime files (``worker.js``, ``brian.js``) and an
    ``index.html`` page, and serves the result with ``emrun``.
    """

    def __init__(self, *args, **kwds):
        """
        Initialize the WASM standalone device.

        Parameters
        ----------
        *args : tuple
            Positional arguments passed to ``CPPStandaloneDevice``.
        **kwds : dict
            Keyword arguments passed to ``CPPStandaloneDevice``.
        """
        # Variables to hand back to JavaScript after the run. Stays ``None``
        # until `transfer_only` is called; passed to the ``objects`` template.
        self.transfer_results = None
        super().__init__(*args, **kwds)

    def transfer_only(self, variableviews):
        """
        Mark variables for transfer from WASM to JavaScript.

        Only the given variables will be made available in JavaScript after
        the simulation completes. This may be called at most once.

        Parameters
        ----------
        variableviews : iterable
            ``VariableView`` objects whose underlying variables should be
            transferred.

        Raises
        ------
        AssertionError
            If the variables to transfer have already been set.
        """
        # Raised explicitly (instead of via ``assert``) so that the check also
        # holds when Python runs with optimizations enabled (``-O``).
        if self.transfer_results is not None:
            raise AssertionError(
                "transfer_only() has already been called; the variables to "
                "transfer can only be set once."
            )
        self.transfer_results = [variableview.variable for variableview in variableviews]

    def activate(self, *args, **kwargs):
        """
        Activate the WASM standalone device.

        After the parent activation, this makes the code object templater
        prefer the ``brian2wasm`` templates. It also ensures that
        ``<emscripten.h>`` is among the C++ headers.

        Parameters
        ----------
        *args : tuple
            Positional arguments passed to the parent ``activate``.
        **kwargs : dict
            Keyword arguments passed to the parent ``activate``.
        """
        super().activate(*args, **kwargs)
        # Overwrite the templater to prefer our templates
        codeobj_class = self.code_object_class()
        codeobj_class.templater = codeobj_class.templater.derive('brian2wasm')
        if '<emscripten.h>' not in prefs.codegen.cpp.headers:
            prefs.codegen.cpp.headers += ['<emscripten.h>']

    def generate_objects_source(
        self,
        writer,
        arange_arrays,
        synapses,
        static_array_specs,
        networks,
        timed_arrays,
    ):
        """
        Generate the ``objects.*`` source files.

        Renders the ``objects`` template with the device's arrays, clocks,
        code objects and the variables selected via `transfer_only`.

        Parameters
        ----------
        writer : CPPWriter
            Writer used to store the generated code.
        arange_arrays : dict
            Specifications for arange arrays.
        synapses : set
            ``Synapses`` objects in the simulation.
        static_array_specs : dict
            Specifications for static arrays.
        networks : set
            ``Network`` objects in the simulation.
        timed_arrays : dict
            Specifications for timed arrays.
        """
        arr_tmp = self.code_object_class().templater.objects(
            None,
            None,
            array_specs=self.arrays,
            dynamic_array_specs=self.dynamic_arrays,
            dynamic_array_2d_specs=self.dynamic_arrays_2d,
            zero_arrays=self.zero_arrays,
            arange_arrays=arange_arrays,
            synapses=synapses,
            clocks=self.clocks,
            static_array_specs=static_array_specs,
            networks=networks,
            get_array_filename=self.get_array_filename,
            get_array_name=self.get_array_name,
            profiled_codeobjects=self.profiled_codeobjects,
            code_objects=list(self.code_objects.values()),
            timed_arrays=timed_arrays,
            transfer_results=self.transfer_results,
        )
        writer.write("objects.*", arr_tmp)

    def generate_makefile(self, writer, compiler, compiler_flags, linker_flags, nb_threads, debug):
        """
        Generate the Emscripten makefile.

        The file is ``win_makefile`` on Windows and ``makefile`` elsewhere.
        This also resolves the emsdk location and stores it in the
        ``devices.wasm_standalone.emsdk_directory`` preference, where `run`
        reads it later.

        Parameters
        ----------
        writer : CPPWriter
            Writer holding the source/header file lists; used to store the makefile.
        compiler : str
            Compiler name (not used by the makefile templates).
        compiler_flags : str
            Space-separated compiler flags.
        linker_flags : str
            Space-separated linker flags.
        nb_threads : int
            Number of threads (not used for WASM).
        debug : bool
            Whether to add debug flags (``-g -DDEBUG`` / ``-g``).

        Raises
        ------
        ValueError
            If no emsdk location is configured: the preference and the
            ``EMSDK`` / ``CONDA_EMSDK_DIR`` environment variables are all empty.
        """
        preloads = ' '.join(f'--preload-file static_arrays/{static_array}'
                            for static_array in sorted(self.static_arrays))
        rm_cmd = 'rm $(OBJS) $(PROGRAM) $(DEPS)'
        if debug:
            compiler_debug_flags = '-g -DDEBUG'
            linker_debug_flags = '-g'
        else:
            compiler_debug_flags = ''
            linker_debug_flags = ''
        source_files = ' '.join(sorted(writer.source_files))
        header_files = ' '.join(sorted(writer.header_files))
        preamble_file = os.path.join(os.path.dirname(__file__), 'templates', 'pre.js')

        # Resolve the emsdk location: explicit preference first, then the
        # environment variables set by an activated emsdk / the conda package.
        # Validate *before* storing it, so a str-typed preference never
        # receives ``None``.
        emsdk_path = (prefs.devices.wasm_standalone.emsdk_directory
                      or os.environ.get("EMSDK")
                      or os.environ.get("CONDA_EMSDK_DIR"))
        if not emsdk_path:
            raise ValueError("Please provide the path to the emsdk directory in the preferences")
        # Remember the resolved path; `run` sources the emsdk environment from it
        prefs.devices.wasm_standalone.emsdk_directory = emsdk_path
        emsdk_version = prefs.devices.wasm_standalone.emsdk_version

        templater = self.code_object_class().templater
        if os.name == 'nt':
            template, outputfile_name = templater.win_makefile, 'win_makefile'
        else:
            template, outputfile_name = templater.makefile, 'makefile'
        makefile_tmp = template(None, None,
                                source_files=source_files,
                                header_files=header_files,
                                compiler_flags=compiler_flags,
                                compiler_debug_flags=compiler_debug_flags,
                                linker_debug_flags=linker_debug_flags,
                                linker_flags=linker_flags,
                                preloads=preloads,
                                preamble_file=preamble_file,
                                rm_cmd=rm_cmd,
                                emsdk_path=emsdk_path,
                                emsdk_version=emsdk_version)
        writer.write(outputfile_name, makefile_tmp)

    def copy_source_files(self, writer, directory):
        """
        Copy the source files and the JavaScript runtime into the build folder.

        In addition to the parent's files, this copies ``worker.js`` and
        ``brian.js`` from this package's ``templates`` folder. A user-provided
        ``html_file`` build option, if set, is copied as ``index.html``.

        Parameters
        ----------
        writer : CPPWriter
            Writer with the source file information (used by the parent).
        directory : str
            Target build directory.
        """
        super().copy_source_files(writer, directory)
        templates_dir = os.path.join(os.path.dirname(__file__), 'templates')
        for js_file in ('worker.js', 'brian.js'):
            shutil.copy(os.path.join(templates_dir, js_file), directory)
        html_file = self.build_options.get('html_file')
        if html_file:
            shutil.copy(html_file, os.path.join(directory, 'index.html'))

    def get_report_func(self, report):
        """
        Generate the C++ ``report_progress`` function.

        For the built-in reporters, the generated function also posts a
        ``progress`` message to JavaScript via ``EM_ASM``.

        Parameters
        ----------
        report : str or None
            ``None`` (no reporting), ``'text'``/``'stdout'`` (print to
            ``std::cout``), ``'stderr'`` (print to ``std::cerr``), or any other
            string, which is used as the body of a custom ``report_progress``
            function.

        Raises
        ------
        TypeError
            If ``report`` is neither ``None`` nor a string.

        Returns
        -------
        str
            The C++ source code (empty string if ``report`` is ``None``).
        """
        # Code for a progress reporting function
        standard_code = """
        std::string _format_time(float time_in_s)
        {
            float divisors[] = {24*60*60, 60*60, 60, 1};
            char letters[] = {'d', 'h', 'm', 's'};
            float remaining = time_in_s;
            std::string text = "";
            int time_to_represent;
            for (int i =0; i < sizeof(divisors)/sizeof(float); i++)
            {
                time_to_represent = int(remaining / divisors[i]);
                remaining -= time_to_represent * divisors[i];
                if (time_to_represent > 0 || text.length())
                {
                    if(text.length() > 0)
                    {
                        text += " ";
                    }
                    text += (std::to_string(time_to_represent)+letters[i]);
                }
            }
            //less than one second
            if(text.length() == 0)
            {
                text = "< 1s";
            }
            return text;
        }
        void report_progress(const double elapsed, const double completed, const double start, const double duration)
        {
            // Send progress to javascript
            EM_ASM({ (postMessage({ type: 'progress', elapsed: $0, completed: $1, start: $2, duration: $3})); }, elapsed, completed, start, duration);
            if (completed == 0.0)
            {
                %STREAMNAME% << "Starting simulation at t=" << start << " s for duration " << duration << " s";
            } else
            {
                %STREAMNAME% << completed*duration << " s (" << (int)(completed*100.)
                             << "%) simulated in " << _format_time(elapsed) << " (" << elapsed << "s)";
                if (completed < 1.0)
                {
                    const int remaining = (int)((1-completed)/completed*elapsed+0.5);
                    %STREAMNAME% << ", estimated " << _format_time(remaining) << " remaining.";
                }
            }
            %STREAMNAME% << std::endl << std::flush;
        }
        """
        if report is None:
            report_func = ''
        elif report in ('text', 'stdout'):
            report_func = standard_code.replace('%STREAMNAME%', 'std::cout')
        elif report == 'stderr':
            report_func = standard_code.replace('%STREAMNAME%', 'std::cerr')
        elif isinstance(report, str):
            report_func = """
        void report_progress(const double elapsed, const double completed, const double start, const double duration)
        {
            %REPORT%
        }
        """.replace('%REPORT%', report)
        else:
            raise TypeError("report argument has to be either 'text', "
                            "'stdout', 'stderr', or the code for a report "
                            "function")
        return report_func

    def network_run(self, net, duration, report=None, report_period=10*second,
                    namespace=None, profile=None, level=0, **kwds):
        """
        Record a network run for the WASM backend.

        Prepares the network's objects and queues the C++ code that runs the
        network. If ``build_on_run`` is set, this also triggers the build.

        Parameters
        ----------
        net : Network
            The Brian2 network to simulate.
        duration : Quantity
            Duration of the simulation (must be non-negative).
        report : str or None, optional
            Progress reporting mode (see `get_report_func`). Default is None.
        report_period : Quantity, optional
            Interval between progress reports. Default is 10*second.
        namespace : dict, optional
            Namespace for resolving external identifiers; taken from the
            caller's frames if ``None``.
        profile : bool, optional
            Whether to enable profiling; if ``None``, the ``profile`` build
            option is used (default False).
        level : int, optional
            Additional stack levels for namespace detection. Default is 0.
        **kwds : dict
            Unsupported; a warning is logged if any are given.

        Raises
        ------
        ValueError
            If ``duration`` is negative.
        NotImplementedError
            If runs use different report functions.
        RuntimeError
            If ``build_on_run`` is set and the device has already been run.
        """
        if duration < 0:
            raise ValueError(
                f"Function 'run' expected a non-negative duration but got '{duration}'"
            )

        self.networks.add(net)
        if kwds:
            logger.warn('Unsupported keyword argument(s) provided for run: '
                        '%s' % ', '.join(kwds.keys()))
        # We store this as an instance variable for later access by the
        # `code_object` method
        self.enable_profiling = profile
        # Allow setting `profile` in the `set_device` call (used e.g. in
        # brian2cuda SpeedTest configurations)
        if profile is None:
            self.enable_profiling = self.build_options.get('profile', False)

        all_objects = net.sorted_objects
        net._clocks = {obj.clock for obj in all_objects}
        t_end = net.t + duration
        for clock in net._clocks:
            clock.set_interval(net.t, t_end)

        # Get the local namespace (must stay directly in this method so that
        # the stack depth used by `level` is unchanged)
        if namespace is None:
            namespace = get_local_namespace(level=level + 2)

        net.before_run(namespace)

        self.synapses |= {s for s in net.objects if isinstance(s, Synapses)}

        self.clocks.update(net._clocks)
        net.t_ = float(t_end)

        # TODO: remove this horrible hack
        for clock in self.clocks:
            if clock.name == 'clock':
                clock._name = '_clock'

        # Extract all the CodeObjects
        # Note that since we ran the Network object, these CodeObjects will be
        # sorted into the right running order, assuming that there is only one clock
        code_objects = [(obj.clock, codeobj)
                        for obj in all_objects if obj.active
                        for codeobj in obj._code_objects]

        report_func = self.get_report_func(report)
        if report_func != '':
            if self.report_func != '' and report_func != self.report_func:
                raise NotImplementedError("The C++ standalone device does not "
                                          "support multiple report functions, "
                                          "each run has to use the same (or "
                                          "none).")
            self.report_func = report_func

        report_call = 'report_progress' if report_func else 'NULL'

        # Generate the updaters
        run_lines = [f'{net.name}.clear();']
        all_clocks = set()
        for clock, codeobj in code_objects:
            run_lines.append(f'{net.name}.add(&{clock.name}, _run_{codeobj.name});')
            all_clocks.add(clock)

        # Under some rare circumstances (e.g. a NeuronGroup only defining a
        # subexpression that is used by other groups (via linking, or recorded
        # by a StateMonitor) *and* not calculating anything itself *and* using a
        # different clock than all other objects) a clock that is not used by
        # any code object should nevertheless advance during the run. We include
        # such clocks without a code function in the network.
        for clock in net._clocks:
            if clock not in all_clocks:
                run_lines.append(f'{net.name}.add(&{clock.name}, NULL);')

        run_lines.extend(self.code_lines['before_network_run'])
        if not self.run_args_applied:
            run_lines.append('set_from_command_line(args);')
            self.run_args_applied = True
        run_lines.append(f'{net.name}.run({float(duration)!r}, {report_call}, {float(report_period)!r});')
        run_lines.extend(self.code_lines['after_network_run'])
        self.main_queue.append(('run_network', (net, run_lines)))

        net.after_run()

        # Manually set the cache for the clocks, simulation scripts might
        # want to access the time (which has been set in code and is therefore
        # not accessible by the normal means until the code has been built and
        # run)
        for clock in net._clocks:
            self.array_cache[clock.variables['timestep']] = np.array([clock._i_end])
            self.array_cache[clock.variables['t']] = np.array([clock._i_end * clock.dt_])

        if self.build_on_run:
            if self.has_been_run:
                raise RuntimeError("The network has already been built and run "
                                   "before. Use set_device with "
                                   "build_on_run=False and an explicit "
                                   "device.build call to use multiple run "
                                   "statements with this device.")
            self.build(direct_call=False, **self.build_options)

    def run(self, directory, results_directory, with_output, run_args):
        """
        Prepare ``index.html`` and serve the compiled simulation with ``emrun``.

        The HTML page comes from the ``html_file`` build option. If that is
        not set, a ``.html`` file next to the running script is used, if one
        exists. Otherwise the page is rendered from the ``html_template``
        template with ``html_content``, where missing keys are filled from
        ``DEFAULT_HTML_CONTENT``. If the environment variable
        ``BRIAN2WASM_NO_SERVER`` is ``'1'``, the server is not started.

        Parameters
        ----------
        directory : str
            Build directory containing the compiled files.
        results_directory : str
            Accepted for interface compatibility; not used here.
        with_output : bool
            Accepted for interface compatibility; not used here.
        run_args : list of str
            Extra arguments appended to the ``emrun index.html`` command.

        Raises
        ------
        KeyError
            If ``html_content`` contains a key not in ``DEFAULT_HTML_CONTENT``.
        """
        html_file = self.build_options.get('html_file')
        html_content = self.build_options.get('html_content')
        if html_file is None:
            # Look for an HTML file next to the running script. Interactive
            # sessions (REPL, notebooks) have no ``__main__.__file__``.
            import __main__
            main_file = getattr(__main__, '__file__', None)
            if main_file is not None:
                html_file = os.path.splitext(main_file)[0] + '.html'

        if html_file is None or not os.path.exists(html_file):
            if html_content is None:
                html_content = dict(DEFAULT_HTML_CONTENT)
            else:
                for key in html_content:
                    if key not in DEFAULT_HTML_CONTENT:
                        raise KeyError(
                            f"Key '{key}' is not a valid key for html_content. "
                            f"Allowed keys: {', '.join(DEFAULT_HTML_CONTENT.keys())}"
                        )
                # Build a new dict so the caller's html_content is not modified
                html_content = {**DEFAULT_HTML_CONTENT, **html_content}
            index_html = os.path.join(self.project_dir, 'index.html')
            # Create HTML file from template in code directory
            html_tmp = self.code_object_class().templater.html_template(None, None, **html_content)
            with open(index_html, 'w') as f:
                f.write(html_tmp)
        else:
            # HTML file exists, copy it to the project directory
            shutil.copy(html_file, os.path.join(self.project_dir, 'index.html'))

        with in_directory(directory):
            if os.environ.get('BRIAN2WASM_NO_SERVER', '0') == '1':
                print("Skipping server startup (--no-server flag set)")
                return
            emsdk_path = prefs.devices.wasm_standalone.emsdk_directory
            os.environ['EMSDK_QUIET'] = '1'
            emrun_cmd = ' '.join(['emrun', 'index.html'] + list(run_args))
            if platform.system() == "Windows":
                cmd_line = f'cmd.exe /C "call {emsdk_path}\\emsdk_env.bat & {emrun_cmd}"'
            else:
                import shlex
                # Quote the emsdk path and the whole bash command so that
                # paths containing spaces or quotes do not break the command.
                env_script = shlex.quote(f'{emsdk_path}/emsdk_env.sh')
                bash_cmd = f'source {env_script} && {emrun_cmd}'
                cmd_line = f'/bin/bash -c {shlex.quote(bash_cmd)}'
            start_time = time.time()
            os.system(cmd_line)
            self.timers['run_binary'] = time.time() - start_time

    def _get_compiler_and_linker_flags(self):
        """
        Assemble the ``emcc`` compiler and linker flags for `build`.

        Combines the device settings, the ``codegen.cpp.*`` and
        ``devices.wasm_standalone.*`` preferences, and the ``compiler_kwds``
        of all code objects. Removes ``advapi32`` from ``self.libraries``
        because it is only relevant when targeting Windows.

        Returns
        -------
        tuple of (list of str, list of str)
            The compiler flags and the linker flags.
        """
        codeobjs = list(self.code_objects.values())
        extra_compile_args = self.extra_compile_args + prefs["devices.wasm_standalone.emcc_compile_args"]
        extra_link_args = self.extra_link_args + prefs["devices.wasm_standalone.emcc_link_args"]
        define_macros = (
            self.define_macros
            + prefs["codegen.cpp.define_macros"]
            + [m for c in codeobjs for m in c.compiler_kwds.get("define_macros", [])]
        )
        include_dirs = (
            self.include_dirs
            + prefs["codegen.cpp.include_dirs"]
            + [d for c in codeobjs for d in c.compiler_kwds.get("include_dirs", [])]
        )
        library_dirs = (
            self.library_dirs
            + prefs["codegen.cpp.library_dirs"]
            + [d for c in codeobjs for d in c.compiler_kwds.get("library_dirs", [])]
        )
        # This library is only relevant when targetting Windows
        if "advapi32" in self.libraries:
            self.libraries.remove("advapi32")
        libraries = (
            self.libraries
            + prefs["codegen.cpp.libraries"]
            + [lib for c in codeobjs for lib in c.compiler_kwds.get("libraries", [])]
        )

        macro_flags = []
        for macro in define_macros:
            if isinstance(macro, (list, tuple)):
                name, val = macro if len(macro) == 2 else (macro[0], None)
            else:
                name, val = macro, None
            macro_flags.append(f"-D{name}={val}" if val is not None else f"-D{name}")

        compiler_flags = (
            extra_compile_args
            + macro_flags
            + [f"-I{d}" for d in include_dirs]
        )
        linker_flags = (
            extra_link_args
            + [f"-L{d}" for d in library_dirs]
            + [f"-l{lib}" for lib in libraries]
        )
        return compiler_flags, linker_flags

    def build(self, html_file=None, html_content=None, **kwds):
        """
        Build the project for the WASM backend.

        Generates all sources and the makefile. Depending on the options, it
        then compiles them with ``emcc`` and serves the result via `run`.

        Parameters
        ----------
        html_file : str, optional
            Path to a custom HTML page, copied as ``index.html``.
        html_content : dict, optional
            Values for the default HTML template (see ``DEFAULT_HTML_CONTENT``).
        directory : str, optional
            Target build directory. Defaults to a new temporary directory
            (``tempfile.mkdtemp(prefix="brian_standalone_")``).
        results_directory : str, optional
            Relative sub-folder for results. Defaults to ``"results"``.
        compile : bool, optional
            Whether to compile the sources. Default is True.
        run : bool, optional
            Whether to run the generated bundle. Default is True.
        debug : bool, optional
            Whether to include debug flags. Default is False.
        clean : bool, optional
            Whether to clean old build artifacts before compiling. Default is False.
        with_output : bool, optional
            Passed on to `run`. Default is True.
        additional_source_files : list of str, optional
            Extra ``.cpp`` files to include (the list is not modified).
        run_args : list of str, optional
            Extra arguments for the ``emrun`` command.
        direct_call : bool, optional
            True when called by the user; False when triggered by a run.
        **kwds : dict
            The options above.

        Raises
        ------
        RuntimeError
            If called manually with ``build_on_run=True``, or if the device has
            already been run.
        TypeError
            If ``results_directory`` is an absolute path.
        ValueError
            If the number of OpenMP threads is negative or object names are
            not unique.
        """
        self.build_options.update({'html_file': html_file, 'html_content': html_content})
        direct_call = kwds.get('direct_call', True)
        # Copy so that extending it below does not modify the caller's list
        # (or the list stored in ``build_options``)
        additional_source_files = list(kwds.get('additional_source_files') or [])
        run_args = kwds.get('run_args', [])
        do_run = kwds.get('run', True)
        debug = kwds.get('debug', False)
        clean = kwds.get('clean', False)
        with_output = kwds.get('with_output', True)
        results_directory = kwds.get('results_directory', 'results')
        do_compile = kwds.get('compile', True)

        if self.build_on_run and direct_call:
            raise RuntimeError(
                "You used set_device with build_on_run=True "
                "(the default option), which will automatically "
                "build the simulation at the first encountered "
                "run call - do not call device.build manually "
                "in this case. If you want to call it manually, "
                "e.g. because you have multiple run calls, use "
                "set_device with build_on_run=False."
            )
        if self.has_been_run:
            raise RuntimeError(
                "The network has already been built and run "
                "before. To build several simulations in "
                'the same script, call "device.reinit()" '
                'and "device.activate()". Note that you '
                "will have to set build options (e.g. the "
                "directory) and defaultclock.dt again."
            )
        # Validate before creating any directories
        if os.path.isabs(results_directory):
            raise TypeError(
                "The 'results_directory' argument needs to be a relative path but was "
                f"'{results_directory}'."
            )

        directory = kwds.get('directory') or tempfile.mkdtemp(prefix="brian_standalone_")
        self.project_dir = directory
        ensure_directory(directory)
        # Translate path to absolute path which ends with /
        self.results_dir = os.path.join(
            os.path.abspath(os.path.join(directory, results_directory)), ""
        )

        compiler = "emcc"
        compiler_flags, linker_flags = self._get_compiler_and_linker_flags()
        additional_source_files += [
            f for c in self.code_objects.values() for f in c.compiler_kwds.get("sources", [])
        ]

        for subdir in ("code_objects", "results", "static_arrays"):
            ensure_directory(os.path.join(directory, subdir))

        self.writer = CPPWriter(directory)

        nb_threads = prefs.devices.cpp_standalone.openmp_threads
        if nb_threads < 0:
            raise ValueError("OpenMP threads cannot be negative.")
        self.check_openmp_compatible(nb_threads)

        self.write_static_arrays(directory)

        names = [o.name for n in self.networks for o in n.sorted_objects]
        dupes = [name for name, count in Counter(names).items() if count > 1]
        if dupes:
            raise ValueError("Duplicate object names: " + ", ".join(f"'{n}'" for n in dupes))

        self.generate_objects_source(self.writer, self.arange_arrays, self.synapses,
                                     self.static_array_specs, self.networks, self.timed_arrays)
        self.generate_main_source(self.writer)
        self.generate_codeobj_source(self.writer)
        self.generate_network_source(self.writer, compiler)
        self.generate_synapses_classes_source(self.writer)
        self.generate_run_source(self.writer)
        self.copy_source_files(self.writer, directory)

        self.writer.source_files.update(additional_source_files)

        self.generate_makefile(
            self.writer,
            compiler,
            compiler_flags=" ".join(compiler_flags),
            linker_flags=" ".join(linker_flags),
            nb_threads=nb_threads,
            debug=debug,
        )

        if do_compile:
            # We switch the compiler name back to `msvc` on Windows, to make sure it uses `nmake`
            self.compile_source(directory, 'msvc' if os.name == 'nt' else compiler, debug, clean)
        if do_run:
            self.run(directory, results_directory, with_output, run_args)

        timers = self.timers
        measurements = []
        for label, group, key in (
            ("'make clean'", "compile", "clean"),
            ("'make'", "compile", "make"),
            ("running 'main'", "run_binary", None),
        ):
            value = timers[group][key] if key else timers[group]
            if value is not None:
                measurements.append(f"{label}: {value:.2f}s")
        logger.debug("Time measurements: " + ", ".join(measurements))
# Create the module-level device instance and register it with Brian under
# the name 'wasm_standalone' (the name used with ``set_device``).
all_devices['wasm_standalone'] = wasm_standalone_device = WASMStandaloneDevice()