diff options
| -rw-r--r-- | .gitignore | 1 | ||||
| -rwxr-xr-x | training/train_cnn.py | 35 | ||||
| -rw-r--r-- | training/workspaces/main/shaders/cnn/cnn_weights_generated.wgsl | 158 |
3 files changed, 192 insertions, 2 deletions
@@ -67,3 +67,4 @@ demo_timeline.html timeline.txt timeline.html Testing/ +training/checkpoints/ diff --git a/training/train_cnn.py b/training/train_cnn.py index c249947..82f0b48 100755 --- a/training/train_cnn.py +++ b/training/train_cnn.py @@ -246,9 +246,22 @@ def train(args): criterion = nn.MSELoss() optimizer = optim.Adam(model.parameters(), lr=args.learning_rate) + # Resume from checkpoint + start_epoch = 0 + if args.resume: + if os.path.exists(args.resume): + print(f"Loading checkpoint from {args.resume}...") + checkpoint = torch.load(args.resume, map_location=device) + model.load_state_dict(checkpoint['model_state']) + optimizer.load_state_dict(checkpoint['optimizer_state']) + start_epoch = checkpoint['epoch'] + 1 + print(f"Resumed from epoch {start_epoch}") + else: + print(f"Warning: Checkpoint file '{args.resume}' not found, starting from scratch") + # Training loop - print(f"\nTraining for {args.epochs} epochs...") - for epoch in range(args.epochs): + print(f"\nTraining for {args.epochs} epochs (starting from epoch {start_epoch})...") + for epoch in range(start_epoch, args.epochs): epoch_loss = 0.0 for batch_idx, (inputs, targets) in enumerate(dataloader): inputs, targets = inputs.to(device), targets.to(device) @@ -265,6 +278,21 @@ def train(args): if (epoch + 1) % 10 == 0: print(f"Epoch [{epoch+1}/{args.epochs}], Loss: {avg_loss:.6f}") + # Save checkpoint + if args.checkpoint_every > 0 and (epoch + 1) % args.checkpoint_every == 0: + checkpoint_dir = args.checkpoint_dir or 'training/checkpoints' + os.makedirs(checkpoint_dir, exist_ok=True) + checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch+1}.pth') + torch.save({ + 'epoch': epoch, + 'model_state': model.state_dict(), + 'optimizer_state': optimizer.state_dict(), + 'loss': avg_loss, + 'kernel_sizes': kernel_sizes, + 'num_layers': args.layers + }, checkpoint_path) + print(f"Saved checkpoint to {checkpoint_path}") + # Export weights output_path = args.output or 'workspaces/main/shaders/cnn/cnn_weights_generated.wgsl' print(f"\nExporting weights to {output_path}...") @@ -284,6 +312,9 @@ def main(): parser.add_argument('--batch_size', type=int, default=4, help='Batch size (default: 4)') parser.add_argument('--learning_rate', type=float, default=0.001, help='Learning rate (default: 0.001)') parser.add_argument('--output', help='Output WGSL file path (default: workspaces/main/shaders/cnn/cnn_weights_generated.wgsl)') + parser.add_argument('--checkpoint-every', type=int, default=0, help='Save checkpoint every N epochs (default: 0 = disabled)') + parser.add_argument('--checkpoint-dir', help='Checkpoint directory (default: training/checkpoints)') + parser.add_argument('--resume', help='Resume from checkpoint file') args = parser.parse_args() diff --git a/training/workspaces/main/shaders/cnn/cnn_weights_generated.wgsl b/training/workspaces/main/shaders/cnn/cnn_weights_generated.wgsl new file mode 100644 index 0000000..dae81df --- /dev/null +++ b/training/workspaces/main/shaders/cnn/cnn_weights_generated.wgsl @@ -0,0 +1,158 @@ +// Auto-generated CNN weights +// DO NOT EDIT - Generated by train_cnn.py + +const rgba_weights_layer0: array<mat4x4<f32>, 9> = array( + mat4x4<f32>( + -0.019696, -0.045138, -0.059372, + 0.113637, -0.026176, -0.204699, + -0.147723, -0.124720, -0.133641, + ), + mat4x4<f32>( + -0.011820, -0.110039, -0.019111, + 0.102596, -0.053469, -0.090972, + -0.106286, 0.062616, -0.211309, + ), + mat4x4<f32>( + 0.169672, -0.188668, 0.097992, + -0.048049, 0.035012, -0.028287, + 0.041841, 0.113846, 0.092006, + ), + mat4x4<f32>( + 0.084688, -0.173117, -0.130135, + 0.125052, 0.070060, -0.072493, + -0.081996, -0.041021, -0.200688, + ), + mat4x4<f32>( + 0.180550, 0.018555, -0.092889, + 0.105823, 0.109215, 0.042989, + -0.116116, 0.115354, 0.044726, + ), + mat4x4<f32>( + 0.069597, -0.156086, -0.116919, + 0.003641, -0.033090, 0.077686, + -0.090117, 0.047527, 0.093449, + ), + mat4x4<f32>( + -0.007961, -0.201232, -0.094087, + 0.041521, -0.001265, -0.164458, + -0.063295, -0.177367, 0.120887, + ), + mat4x4<f32>( + 0.005358, -0.153663, 0.234817, + 0.094452, -0.030598, -0.159715, + -0.025096, 0.010606, -0.151786, + ), + mat4x4<f32>( + 0.035922, 0.039006, -0.073426, + 0.234309, 0.042990, -0.074330, + 0.129497, -0.084083, -0.165691, + ) +); + +const coord_weights_layer0 = mat2x4<f32>( + 0.156995, -0.026005, 0.159550, + 0.112678, -0.021301, 0.106653 +); + +const bias_layer0 = vec4<f32>(0.149566, -0.002723, 0.142744); + +const weights_layer1: array<mat4x4<f32>, 9> = array( + mat4x4<f32>( + 0.198730, -0.060590, -0.126001, + 0.018094, 0.099855, 0.043531, + -0.048028, 0.024975, -0.055560, + ), + mat4x4<f32>( + 0.093012, -0.056168, 0.075685, + -0.104572, 0.202161, 0.093453, + 0.008470, 0.190414, -0.121853, + ), + mat4x4<f32>( + 0.157523, -0.278521, 0.267972, + 0.226318, 0.108021, -0.020615, + 0.116906, 0.094663, 0.103058, + ), + mat4x4<f32>( + 0.184815, -0.167385, -0.081513, + 0.167595, 0.147724, -0.034069, + 0.109272, 0.149283, 0.022741, + ), + mat4x4<f32>( + -0.133319, 0.069405, 0.028862, + -0.044914, -0.121720, 0.074758, + 0.150973, 0.086887, 0.193997, + ), + mat4x4<f32>( + 0.123384, -0.157817, -0.053264, + 0.216874, 0.024062, 0.227470, + 0.092232, 0.156942, 0.098989, + ), + mat4x4<f32>( + -0.074328, -0.265180, 0.065633, + 0.033679, 0.175748, 0.178567, + 0.168913, 0.192317, -0.015507, + ), + mat4x4<f32>( + -0.103567, -0.081663, 0.239707, + 0.020591, 0.031346, 0.089577, + -0.040636, 0.061481, 0.215428, + ), + mat4x4<f32>( + 0.103399, -0.291323, 0.220388, + 0.163876, 0.106383, 0.175615, + 0.050511, 0.210950, -0.143280, + ) +); + +const bias_layer1 = vec4<f32>(0.273340, 0.183151, 0.057200); + +const weights_layer2: array<mat4x4<f32>, 9> = array( + mat4x4<f32>( + 0.170688, -0.158379, -0.073057, + -0.213429, -0.075772, -0.117451, + -0.265536, -0.066896, 0.185188, + ), + mat4x4<f32>( + 0.061069, -0.267237, -0.057030, + -0.112682, -0.001723, 0.020779, + -0.158726, -0.027319, -0.133134, + ), + mat4x4<f32>( + -0.036597, 0.000282, -0.286058, + -0.056992, 0.129227, 0.037650, + -0.305341, -0.082011, 0.155333, + ), + mat4x4<f32>( + 0.146811, 0.086471, -0.092652, + -0.083987, -0.164501, 0.005801, + -0.108568, 0.079618, 0.011061, + ), + mat4x4<f32>( + 0.008716, -0.174373, 0.038516, + -0.263207, -0.201249, -0.106428, + -0.321199, 0.139540, -0.069047, + ), + mat4x4<f32>( + -0.099231, -0.037154, -0.189117, + 0.014380, 0.102996, 0.068944, + -0.011073, 0.175106, 0.019059, + ), + mat4x4<f32>( + -0.170030, -0.077528, -0.038504, + 0.042379, -0.198288, 0.008895, + -0.144090, -0.129658, 0.215823, + ), + mat4x4<f32>( + -0.082481, -0.160808, -0.279220, + -0.029358, 0.021159, -0.037080, + -0.194849, -0.013461, 0.057026, + ), + mat4x4<f32>( + -0.063711, -0.198759, -0.037847, + -0.049292, -0.222896, -0.067384, + -0.167766, -0.090320, 0.106986, + ) +); + +const bias_layer2 = vec4<f32>(0.021260, -0.056985, 0.000823); + |
