summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.gitignore1
-rwxr-xr-xtraining/train_cnn.py35
-rw-r--r--training/workspaces/main/shaders/cnn/cnn_weights_generated.wgsl158
3 files changed, 192 insertions, 2 deletions
diff --git a/.gitignore b/.gitignore
index a9afb7d..7a9ab42 100644
--- a/.gitignore
+++ b/.gitignore
@@ -67,3 +67,4 @@ demo_timeline.html
timeline.txt
timeline.html
Testing/
+training/checkpoints/
diff --git a/training/train_cnn.py b/training/train_cnn.py
index c249947..82f0b48 100755
--- a/training/train_cnn.py
+++ b/training/train_cnn.py
@@ -246,9 +246,22 @@ def train(args):
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)
+ # Resume from checkpoint
+ start_epoch = 0
+ if args.resume:
+ if os.path.exists(args.resume):
+ print(f"Loading checkpoint from {args.resume}...")
+ checkpoint = torch.load(args.resume, map_location=device)
+ model.load_state_dict(checkpoint['model_state'])
+ optimizer.load_state_dict(checkpoint['optimizer_state'])
+ start_epoch = checkpoint['epoch'] + 1
+ print(f"Resumed from epoch {start_epoch}")
+ else:
+ print(f"Warning: Checkpoint file '{args.resume}' not found, starting from scratch")
+
# Training loop
- print(f"\nTraining for {args.epochs} epochs...")
- for epoch in range(args.epochs):
+ print(f"\nTraining for {args.epochs} epochs (starting from epoch {start_epoch})...")
+ for epoch in range(start_epoch, args.epochs):
epoch_loss = 0.0
for batch_idx, (inputs, targets) in enumerate(dataloader):
inputs, targets = inputs.to(device), targets.to(device)
@@ -265,6 +278,21 @@ def train(args):
if (epoch + 1) % 10 == 0:
print(f"Epoch [{epoch+1}/{args.epochs}], Loss: {avg_loss:.6f}")
+ # Save checkpoint
+ if args.checkpoint_every > 0 and (epoch + 1) % args.checkpoint_every == 0:
+ checkpoint_dir = args.checkpoint_dir or 'training/checkpoints'
+ os.makedirs(checkpoint_dir, exist_ok=True)
+ checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch+1}.pth')
+ torch.save({
+ 'epoch': epoch,
+ 'model_state': model.state_dict(),
+ 'optimizer_state': optimizer.state_dict(),
+ 'loss': avg_loss,
+ 'kernel_sizes': kernel_sizes,
+ 'num_layers': args.layers
+ }, checkpoint_path)
+ print(f"Saved checkpoint to {checkpoint_path}")
+
# Export weights
output_path = args.output or 'workspaces/main/shaders/cnn/cnn_weights_generated.wgsl'
print(f"\nExporting weights to {output_path}...")
@@ -284,6 +312,9 @@ def main():
parser.add_argument('--batch_size', type=int, default=4, help='Batch size (default: 4)')
parser.add_argument('--learning_rate', type=float, default=0.001, help='Learning rate (default: 0.001)')
parser.add_argument('--output', help='Output WGSL file path (default: workspaces/main/shaders/cnn/cnn_weights_generated.wgsl)')
+ parser.add_argument('--checkpoint-every', type=int, default=0, help='Save checkpoint every N epochs (default: 0 = disabled)')
+ parser.add_argument('--checkpoint-dir', help='Checkpoint directory (default: training/checkpoints)')
+ parser.add_argument('--resume', help='Resume from checkpoint file')
args = parser.parse_args()
diff --git a/training/workspaces/main/shaders/cnn/cnn_weights_generated.wgsl b/training/workspaces/main/shaders/cnn/cnn_weights_generated.wgsl
new file mode 100644
index 0000000..dae81df
--- /dev/null
+++ b/training/workspaces/main/shaders/cnn/cnn_weights_generated.wgsl
@@ -0,0 +1,158 @@
+// Auto-generated CNN weights
+// DO NOT EDIT - Generated by train_cnn.py
+
+const rgba_weights_layer0: array<mat4x4<f32>, 9> = array(
+ mat4x4<f32>(
+ -0.019696, -0.045138, -0.059372,
+ 0.113637, -0.026176, -0.204699,
+ -0.147723, -0.124720, -0.133641,
+ ),
+ mat4x4<f32>(
+ -0.011820, -0.110039, -0.019111,
+ 0.102596, -0.053469, -0.090972,
+ -0.106286, 0.062616, -0.211309,
+ ),
+ mat4x4<f32>(
+ 0.169672, -0.188668, 0.097992,
+ -0.048049, 0.035012, -0.028287,
+ 0.041841, 0.113846, 0.092006,
+ ),
+ mat4x4<f32>(
+ 0.084688, -0.173117, -0.130135,
+ 0.125052, 0.070060, -0.072493,
+ -0.081996, -0.041021, -0.200688,
+ ),
+ mat4x4<f32>(
+ 0.180550, 0.018555, -0.092889,
+ 0.105823, 0.109215, 0.042989,
+ -0.116116, 0.115354, 0.044726,
+ ),
+ mat4x4<f32>(
+ 0.069597, -0.156086, -0.116919,
+ 0.003641, -0.033090, 0.077686,
+ -0.090117, 0.047527, 0.093449,
+ ),
+ mat4x4<f32>(
+ -0.007961, -0.201232, -0.094087,
+ 0.041521, -0.001265, -0.164458,
+ -0.063295, -0.177367, 0.120887,
+ ),
+ mat4x4<f32>(
+ 0.005358, -0.153663, 0.234817,
+ 0.094452, -0.030598, -0.159715,
+ -0.025096, 0.010606, -0.151786,
+ ),
+ mat4x4<f32>(
+ 0.035922, 0.039006, -0.073426,
+ 0.234309, 0.042990, -0.074330,
+ 0.129497, -0.084083, -0.165691,
+ )
+);
+
+const coord_weights_layer0 = mat2x4<f32>(
+ 0.156995, -0.026005, 0.159550,
+ 0.112678, -0.021301, 0.106653
+);
+
+const bias_layer0 = vec4<f32>(0.149566, -0.002723, 0.142744);
+
+const weights_layer1: array<mat4x4<f32>, 9> = array(
+ mat4x4<f32>(
+ 0.198730, -0.060590, -0.126001,
+ 0.018094, 0.099855, 0.043531,
+ -0.048028, 0.024975, -0.055560,
+ ),
+ mat4x4<f32>(
+ 0.093012, -0.056168, 0.075685,
+ -0.104572, 0.202161, 0.093453,
+ 0.008470, 0.190414, -0.121853,
+ ),
+ mat4x4<f32>(
+ 0.157523, -0.278521, 0.267972,
+ 0.226318, 0.108021, -0.020615,
+ 0.116906, 0.094663, 0.103058,
+ ),
+ mat4x4<f32>(
+ 0.184815, -0.167385, -0.081513,
+ 0.167595, 0.147724, -0.034069,
+ 0.109272, 0.149283, 0.022741,
+ ),
+ mat4x4<f32>(
+ -0.133319, 0.069405, 0.028862,
+ -0.044914, -0.121720, 0.074758,
+ 0.150973, 0.086887, 0.193997,
+ ),
+ mat4x4<f32>(
+ 0.123384, -0.157817, -0.053264,
+ 0.216874, 0.024062, 0.227470,
+ 0.092232, 0.156942, 0.098989,
+ ),
+ mat4x4<f32>(
+ -0.074328, -0.265180, 0.065633,
+ 0.033679, 0.175748, 0.178567,
+ 0.168913, 0.192317, -0.015507,
+ ),
+ mat4x4<f32>(
+ -0.103567, -0.081663, 0.239707,
+ 0.020591, 0.031346, 0.089577,
+ -0.040636, 0.061481, 0.215428,
+ ),
+ mat4x4<f32>(
+ 0.103399, -0.291323, 0.220388,
+ 0.163876, 0.106383, 0.175615,
+ 0.050511, 0.210950, -0.143280,
+ )
+);
+
+const bias_layer1 = vec4<f32>(0.273340, 0.183151, 0.057200);
+
+const weights_layer2: array<mat4x4<f32>, 9> = array(
+ mat4x4<f32>(
+ 0.170688, -0.158379, -0.073057,
+ -0.213429, -0.075772, -0.117451,
+ -0.265536, -0.066896, 0.185188,
+ ),
+ mat4x4<f32>(
+ 0.061069, -0.267237, -0.057030,
+ -0.112682, -0.001723, 0.020779,
+ -0.158726, -0.027319, -0.133134,
+ ),
+ mat4x4<f32>(
+ -0.036597, 0.000282, -0.286058,
+ -0.056992, 0.129227, 0.037650,
+ -0.305341, -0.082011, 0.155333,
+ ),
+ mat4x4<f32>(
+ 0.146811, 0.086471, -0.092652,
+ -0.083987, -0.164501, 0.005801,
+ -0.108568, 0.079618, 0.011061,
+ ),
+ mat4x4<f32>(
+ 0.008716, -0.174373, 0.038516,
+ -0.263207, -0.201249, -0.106428,
+ -0.321199, 0.139540, -0.069047,
+ ),
+ mat4x4<f32>(
+ -0.099231, -0.037154, -0.189117,
+ 0.014380, 0.102996, 0.068944,
+ -0.011073, 0.175106, 0.019059,
+ ),
+ mat4x4<f32>(
+ -0.170030, -0.077528, -0.038504,
+ 0.042379, -0.198288, 0.008895,
+ -0.144090, -0.129658, 0.215823,
+ ),
+ mat4x4<f32>(
+ -0.082481, -0.160808, -0.279220,
+ -0.029358, 0.021159, -0.037080,
+ -0.194849, -0.013461, 0.057026,
+ ),
+ mat4x4<f32>(
+ -0.063711, -0.198759, -0.037847,
+ -0.049292, -0.222896, -0.067384,
+ -0.167766, -0.090320, 0.106986,
+ )
+);
+
+const bias_layer2 = vec4<f32>(0.021260, -0.056985, 0.000823);
+