Diffstat (limited to 'training/export_cnn_v2_weights.py')
| -rwxr-xr-x | training/export_cnn_v2_weights.py | 18 |
1 file changed, 14 insertions, 4 deletions
diff --git a/training/export_cnn_v2_weights.py b/training/export_cnn_v2_weights.py
index 07254fc..bbe94dd 100755
--- a/training/export_cnn_v2_weights.py
+++ b/training/export_cnn_v2_weights.py
@@ -45,11 +45,20 @@ def export_weights_binary(checkpoint_path, output_path):
     state_dict = checkpoint['model_state_dict']
     config = checkpoint['config']
 
-    kernel_size = config.get('kernel_size', 3)
-    num_layers = config.get('num_layers', 3)
+    # Support both old (kernel_size) and new (kernel_sizes) format
+    if 'kernel_sizes' in config:
+        kernel_sizes = config['kernel_sizes']
+    elif 'kernel_size' in config:
+        kernel_size = config['kernel_size']
+        num_layers = config.get('num_layers', 3)
+        kernel_sizes = [kernel_size] * num_layers
+    else:
+        kernel_sizes = [3, 3, 3]  # fallback
+
+    num_layers = config.get('num_layers', len(kernel_sizes))
 
     print(f"Configuration:")
-    print(f"  Kernel size: {kernel_size}×{kernel_size}")
+    print(f"  Kernel sizes: {kernel_sizes}")
     print(f"  Layers: {num_layers}")
     print(f"  Architecture: uniform 12D→4D (bias=False)")
 
@@ -65,6 +74,7 @@ def export_weights_binary(checkpoint_path, output_path):
         layer_weights = state_dict[layer_key].detach().numpy()
         layer_flat = layer_weights.flatten()
 
+        kernel_size = kernel_sizes[i]
 
         layers.append({
             'kernel_size': kernel_size,
@@ -76,7 +86,7 @@ def export_weights_binary(checkpoint_path, output_path):
         all_weights.extend(layer_flat)
         weight_offset += len(layer_flat)
 
-        print(f"  Layer {i}: 12D→4D, {len(layer_flat)} weights")
+        print(f"  Layer {i}: 12D→4D, {kernel_size}×{kernel_size}, {len(layer_flat)} weights")
 
     # Convert to f16
     # TODO: Use 8-bit quantization for 2× size reduction
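For context, the compatibility logic added above can be exercised in isolation. The sketch below is a minimal, hypothetical helper (the name resolve_kernel_sizes is not part of the repository) that mirrors the old-format/new-format handling, assuming the same checkpoint config dictionary keys as the diff:

# Hypothetical helper mirroring the compatibility shim above; not part of the repo.
def resolve_kernel_sizes(config):
    """Return (per-layer kernel sizes, num_layers) from either config format."""
    if 'kernel_sizes' in config:            # new format: explicit per-layer list
        kernel_sizes = list(config['kernel_sizes'])
    elif 'kernel_size' in config:           # old format: one size for all layers
        kernel_sizes = [config['kernel_size']] * config.get('num_layers', 3)
    else:                                   # fallback used by the export script
        kernel_sizes = [3, 3, 3]
    return kernel_sizes, config.get('num_layers', len(kernel_sizes))

# Old-format checkpoint config: uniform kernels
assert resolve_kernel_sizes({'kernel_size': 5, 'num_layers': 2}) == ([5, 5], 2)
# New-format checkpoint config: mixed per-layer kernels
assert resolve_kernel_sizes({'kernel_sizes': [5, 3, 3]}) == ([5, 3, 3], 3)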

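The "Convert to f16" step referenced in the trailing context is not shown in this diff; the snippet below is only an assumed sketch of how such a conversion is typically done with NumPy, not a quote from the file:

import numpy as np

# Assumed sketch of an f16 conversion step; the script's actual code is outside this diff.
all_weights = np.asarray([0.25, -1.5, 3.0], dtype=np.float32)   # placeholder weights
weights_f16 = all_weights.astype(np.float16)                    # halves the payload vs float32
payload = weights_f16.tobytes()                                  # raw bytes in native byte order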