| author | skal &lt;pascal.massimino@gmail.com&gt; | 2026-02-10 07:36:32 +0100 |
|---|---|---|
| committer | skal &lt;pascal.massimino@gmail.com&gt; | 2026-02-10 07:36:32 +0100 |
| commit | c51c146da9590845b864cbba3a7317c5b5bed56a | |
| tree | 80fda2cad06622f367ae004527e4bea21d687e68 /doc | |
| parent | dcd52c3c595c1f37229b880fad11248b98bbced1 | |
initial doc for the CNN project
Diffstat (limited to 'doc')
| -rw-r--r-- | doc/CNN.md | 51 |
| -rw-r--r-- | doc/CNN.py | 244 |
| -rw-r--r-- | doc/CNN.shader | 35 |
3 files changed, 330 insertions, 0 deletions
```diff
diff --git a/doc/CNN.md b/doc/CNN.md
new file mode 100644
index 0000000..8bf2860
--- /dev/null
+++ b/doc/CNN.md
@@ -0,0 +1,51 @@
+# Convolutional Neural Net Shader (CNN) post-processing
+
+## Idea
+
+Have the rendered 3D scene processed by a multi-layer CNN trained offline.
+Input: some rendered scene.
+Output: 'stylized' scene with CNN post-processing.
+
+## Shader implementation
+
+### input / output
+
+Need 1 texture buffer per CNN layer.
+Input: (r,g,b,1/z) for layer 0 (the rendered 3D scene), or the output of layer N-1 for layer N.
+Output: (r,g,b,alpha). The 1/z information is not needed (it can be fetched from the input).
+
+### size of one layer
+
+Notation:
+S: the number of input samples from layer N-1.
+Example: 3x3 input -> S = 3x3 = 9.
+
+Each of the S samples is 4 values (r,g,b, w=1/z).
+
+Each sample is processed by a mat4 matrix: 4 inputs => 4 outputs.
+
+Weight matrix = S x mat4.
+
+Final bias: 4 values.
+
+WGSL code example: see file CNN.shader (currently Shadertoy-style GLSL).
+
+### Layers
+
+Do we need 3 or 4 layers?
+A separate shader for each layer.
+Ping-pong the input/output texture buffers between layers?
+
+## Training
+
+The layer weight/bias data are hard-coded in the shaders.
+Training is done offline with an external Python script.
+File CNN.py contains an example of what the training script could be;
+it is just an example and doesn't match our requirements yet.
+
+We need a repository of reference image pairs (before/after) for training and validation.
+Each input image is randomly sampled into 3x3 patches of (r,g,b,1/z) input samples,
+trained to match the (r,g,b,a) output.
+
+Training generates the .wgsl code for the layers' shaders, and the C++ code for the
+post-processing 'Effect'.
```
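For the sizes sketched above: one 3x3 layer holds S = 9 mat4 weights plus a 4-value bias, i.e. 9·16 + 4 = 148 trainable values per layer. A minimal NumPy sketch of that layout and of one layer evaluation (`apply_layer` is a hypothetical helper, not from the repository):

```python
# Per-layer parameter layout as described in CNN.md ("size of one layer").
import numpy as np

S = 3 * 3                       # number of input samples (3x3 neighborhood)
weights = np.zeros((S, 4, 4))   # one mat4 per sample: 4 inputs -> 4 outputs
bias = np.zeros(4)              # final bias: 4 values

def apply_layer(samples):
    """samples: (S, 4) array of (r,g,b, w=1/z) values from layer N-1."""
    res = bias.copy()
    for s in range(S):
        res += weights[s] @ samples[s]   # mat4 * vec4, accumulated
    return np.tanh(res)                  # activation: tanh (or ReLU)

print(weights.size + bias.size)          # 9*16 + 4 = 148 parameters per layer
```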
```diff
diff --git a/doc/CNN.py b/doc/CNN.py
new file mode 100644
index 0000000..9952c97
--- /dev/null
+++ b/doc/CNN.py
@@ -0,0 +1,244 @@
+# Python source code - Rory McHenry
+#
+# Example training script (TensorFlow 1.x): fits a small residual CNN on a
+# single before/after image pair, then prints a Shadertoy-style GLSL shader
+# for each trained layer.
+
+import os
+
+import numpy as np
+import tensorflow as tf
+from PIL import Image
+
+learning_rate = 0.1
+training_iters = 50
+batch_size = 1
+display_step = 5
+W = 799
+H = 449
+
+im = Image.open('min.png')
+target = Image.open('mout.png')
+
+# Pack the images into (1, H, W, 4) tensors; the 4th channel stays at 1.
+x = np.ones([1, H, W, 4])
+x[0, :, :, 0:3] = np.array(im)[:, :, 0:3].astype(np.float32) / 255
+y = np.ones([1, H, W, 4])
+y[0, :, :, 0:3] = np.array(target)[:, :, 0:3].astype(np.float32) / 255
+
+x = tf.constant(x, dtype=tf.float32, shape=[1, H, W, 4])
+y = tf.constant(y, dtype=tf.float32, shape=[1, H, W, 4])
+
+keep_prob = tf.placeholder(tf.float32)  # reserved for dropout, unused so far
+
+def conv2d(x, W, b, strides=1):
+    x = tf.nn.conv2d(x, W, strides=[1, strides, strides, 1], padding='SAME')
+    x = tf.nn.bias_add(x, b)
+    return tf.tanh(x)
+
+def shaderNet(x, weights, biases, dropout):
+    conv1 = conv2d(x, weights['wc1'], biases['bc1'])
+    conv2 = conv2d(conv1, weights['wc2'], biases['bc2'])
+    conv3 = conv2d(conv2, weights['wc3'], biases['bc3'])
+    conv4 = conv2d(conv3, weights['wc4'], biases['bc4'])
+    # Residual connection: the network learns a correction to its input.
+    return x + conv2d(conv4, weights['out'], biases['out'])
+
+weights = {
+    'wc1': tf.Variable(tf.random_normal([3, 3, 4, 4], stddev=.1)),
+    'wc2': tf.Variable(tf.random_normal([3, 3, 4, 4], stddev=.1)),
+    'wc3': tf.Variable(tf.random_normal([3, 3, 4, 4], stddev=.1)),
+    'wc4': tf.Variable(tf.random_normal([3, 3, 4, 4], stddev=.1)),
+    'out': tf.Variable(tf.random_normal([3, 3, 4, 4], stddev=.1)),
+}
+
+biases = {
+    'bc1': tf.Variable(tf.random_normal([4], stddev=.01)),
+    'bc2': tf.Variable(tf.random_normal([4], stddev=.01)),
+    'bc3': tf.Variable(tf.random_normal([4], stddev=.01)),
+    'bc4': tf.Variable(tf.random_normal([4], stddev=.01)),
+    'out': tf.Variable(tf.random_normal([4], stddev=.01)),
+}
+
+pred = shaderNet(x, weights, biases, keep_prob)
+
+cost = tf.reduce_mean(tf.contrib.losses.mean_squared_error(pred, y))
+optimizer = tf.train.ProximalGradientDescentOptimizer(
+    learning_rate=learning_rate).minimize(cost)
+
+# Note: argmax-based accuracy is a classification metric; it is only a rough
+# indicator for this regression task.
+correct_pred = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
+accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
+
+init = tf.global_variables_initializer()
+saver = tf.train.Saver()
+
+def emit_conv3x3(w, fetch):
+    # Print a GLSL conv3x3() that applies one mat4 per 3x3 tap.
+    print('vec4 conv3x3(vec2 fragCoord) {')
+    print('vec4 res = vec4(0);')
+    for i in range(-1, 2):
+        for j in range(-1, 2):
+            coeffs = ",".join(np.transpose(w[i + 1, j + 1]).reshape(16).astype(str))
+            coord = '(fragCoord + vec2(%d,%d)) / iChannelResolution[0].xy' % (i, j)
+            print('res += ' + fetch(coord) + '*mat4(' + coeffs + ');')
+    print('return res;')
+    print('}')
+
+def emit_activations():
+    # Activation helpers shared by every generated layer shader.
+    print('vec4 ReLU(vec4 i) {')
+    print(' return max(vec4(0),i);')
+    print('}')
+    print('vec4 sigmoid(vec4 i) {')
+    print(' return 1./(1.+exp(-i));')
+    print('}')
+    print('vec4 tanh(vec4 i) {')
+    print(' return (exp(2.*i)-1.)/(exp(2.*i)+1.);')
+    print('}')
+
+with tf.Session() as sess:
+
+    saver.restore(sess, 'C:/Users/rory/py/modelaa.ckpt')
+    #sess.run(init)   # use this instead of restore() on the first run
+
+    step = 1
+    while step * batch_size < training_iters:
+        sess.run(optimizer)
+        if step % display_step == 0:
+            loss, acc = sess.run([cost, accuracy])
+            print("Iter " + str(step * batch_size) + ", Minibatch Loss= " +
+                  "{:.6f}".format(loss) + ", Training Accuracy= " +
+                  "{:.5f}".format(acc))
+            img = np.clip(sess.run(pred), 0, 1)
+            img = Image.fromarray((img[0, :, :, 0:3] * 255).astype(np.uint8), "RGB")
+            img.save(os.path.join(os.getcwd(), "out.png"))
+        step += 1
+    print("Optimization Finished!")
+    print("Testing Accuracy:", sess.run(accuracy))
+
+    save_path = saver.save(sess, 'C:/Users/rory/py/modelaa.ckpt')
+
+    # Emit one shader per hidden layer. Layer 0 samples the scene through
+    # im(); later layers read the previous layer's output from iChannel0.
+    for n, (wk, bk) in enumerate([('wc1', 'bc1'), ('wc2', 'bc2'),
+                                  ('wc3', 'bc3'), ('wc4', 'bc4')]):
+        w = sess.run(weights[wk])
+        b = sess.run(biases[bk])
+        if n == 0:
+            fetch = lambda c: 'im(' + c + ')'
+        else:
+            fetch = lambda c: 'texture(iChannel0,' + c + ')'
+        emit_conv3x3(w, fetch)
+        emit_activations()
+        print('void mainImage( out vec4 fragColor, in vec2 fragCoord )')
+        print('{')
+        print(' fragColor = tanh(conv3x3(fragCoord)+vec4(' +
+              ",".join(b.astype(str)) + '));')
+        print('}')
+
+    # Final layer: split the screen at the mouse for a before/after view.
+    w = sess.run(weights['out'])
+    b = sess.run(biases['out'])
+    emit_conv3x3(w, lambda c: 'texture(iChannel0,' + c + ')')
+    emit_activations()
+    print('void mainImage( out vec4 fragColor, in vec2 fragCoord )')
+    print('{')
+    print(' fragColor = im(uv);')
+    print(' if(fragCoord.x>iMouse.x){')
+    print('  fragColor = tanh(conv3x3(fragCoord)+vec4(' +
+          ",".join(b.astype(str)) + '));')
+    print(' }')
+    print('}')
```
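CNN.md asks for training on randomly sampled 3x3 patches of (r,g,b,1/z) matched to (r,g,b,a) targets, which the single-image script above doesn't do yet. A minimal sketch of that sampling step, assuming before/after images are already loaded as float arrays (`sample_patches` is a hypothetical helper):

```python
# Hypothetical sketch of the random 3x3-patch sampling described in CNN.md.
# `before` is (H, W, 4) = (r,g,b,1/z); `after` is (H, W, 4) = (r,g,b,a).
import numpy as np

def sample_patches(before, after, count, rng=np.random.default_rng()):
    H, W = before.shape[:2]
    xs = rng.integers(1, W - 1, size=count)   # keep the 3x3 window in bounds
    ys = rng.integers(1, H - 1, size=count)
    patches = np.stack([before[y - 1:y + 2, x - 1:x + 2]
                        for x, y in zip(xs, ys)])        # (count, 3, 3, 4)
    targets = after[ys, xs]                              # (count, 4) centers
    return patches, targets
```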
```diff
diff --git a/doc/CNN.shader b/doc/CNN.shader
new file mode 100644
index 0000000..e9418f4
--- /dev/null
+++ b/doc/CNN.shader
@@ -0,0 +1,35 @@
+// Example of a CNN layer shader.
+// Input: 3x3 samples x 4 values in (r,g,b, z or 1/z) format, from the previous layer.
+// Weights: 9 x 4x4
+// Bias: 4 values
+// Activation function: tanh or ReLU
+// Output: r,g,b, and z or 1/z or alpha (TBD)
+
+vec4 conv3x3(vec2 pos) {
+  // input: (r,g,b, z or 1/z)
+  vec2 iRes = 1. / iChannelResolution[0].xy;
+
+  // Bias
+  vec4 res = vec4(-0.04889335483312607,-0.05943099409341812,0.014945696108043194,0.0038716429844498634);
+  // Weights
+  res += texture(iChannel0,(pos + vec2(-1,-1)) * iRes)*mat4(-0.22058571875095367,-0.0025178513024002314,-0.02137291617691517,0.0755973681807518,-0.07658655941486359,-0.15938608348369598,0.039643868803977966,-0.012259022332727909,-0.015218861401081085,-0.050223562866449356,-0.07961801439523697,0.11616108566522598,0.13200008869171143,0.03162014111876488,0.032535750418901443,-0.10636500269174576);
+  res += texture(iChannel0,(pos + vec2(-1,0)) * iRes)*mat4(-0.28930965065956116,-0.11228568106889725,-0.08608853071928024,0.11449871212244034,0.12511080503463745,-0.10040754824876785,0.17461593449115753,-0.15175022184848785,0.03787801042199135,0.20102104544639587,-0.024612706154584885,0.02710619568824768,-0.06153976172208786,-0.10482363402843475,-0.014178688637912273,0.023371122777462006);
+  res += texture(iChannel0,(pos + vec2(-1,1)) * iRes)*mat4(-0.1525852084159851,0.012323809787631035,-0.04394780099391937,-0.07254716753959656,0.18465806543827057,0.14453156292438507,-0.07343120872974396,0.2724604606628418,0.03352152556180954,0.04368482530117035,-0.0542469397187233,0.1053997203707695,-0.04070863872766495,0.0843065083026886,-0.042356643825769424,0.17000557482242584);
+  res += texture(iChannel0,(pos + vec2(0,-1)) * iRes)*mat4(-0.13294945657253265,-0.1913488656282425,0.17023856937885284,0.25633060932159424,-0.10361425578594208,-0.10409805923700333,-0.02211007848381996,0.19673652946949005,0.1772589534521103,0.1924116164445877,0.08171508461236954,0.033589430153369904,-0.3019593060016632,-0.4145629405975342,-0.2238256186246872,-0.033089861273765564);
+  res += texture(iChannel0,(pos + vec2(0,0)) * iRes)*mat4(-0.10532718896865845,-0.009120185859501362,0.1831541657447815,-0.11598552763462067,-0.10732044279575348,-0.1422705501317978,-0.2938171923160553,0.08497025817632675,0.07730063796043396,-0.0065005188807845116,0.06287281960248947,-0.09081853181123734,-0.2506500482559204,0.007432099897414446,-0.05117560178041458,-0.12751594185829163);
+  res += texture(iChannel0,(pos + vec2(0,1)) * iRes)*mat4(-0.031646713614463806,0.03424060717225075,0.02292831800878048,-0.013373860158026218,0.2722923159599304,0.15562108159065247,0.08641268312931061,0.013247879222035408,0.2955344021320343,0.2007416933774948,-0.03226592391729355,0.0658501535654068,0.11414589732885361,0.152848482131958,0.12652148306369781,0.0672551840543747);
+  res += texture(iChannel0,(pos + vec2(1,-1)) * iRes)*mat4(0.22747460007667542,-0.03772164508700371,0.30145782232284546,0.04817598685622215,0.16727609932422638,-0.02036409080028534,0.1638275533914566,0.1533416360616684,0.2211352437734604,-0.02886020578444004,-0.08635003119707108,0.025351224467158318,-0.35834380984306335,0.01970680058002472,-0.013621831312775612,0.14156390726566315);
+  res += texture(iChannel0,(pos + vec2(1,0)) * iRes)*mat4(-0.296324759721756,0.17407724261283875,-0.04902844503521919,0.023473504930734634,-0.26604920625686646,-0.1855679303407669,0.3079718053340912,-0.049569256603717804,-0.30711254477500916,0.05227816477417946,-0.1393774002790451,0.12021080404520035,0.21768616139888763,0.2681339681148529,-0.09689617156982422,-0.21676960587501526);
+  res += texture(iChannel0,(pos + vec2(1,1)) * iRes)*mat4(0.13681215047836304,-0.05690794438123703,0.07499313354492188,0.1611005961894989,-0.18945586681365967,0.011663767509162426,0.10834614187479019,-0.04871741682291031,-0.38159874081611633,-0.15830495953559875,0.22751332819461823,0.170019268989563,-0.1293516904115677,-0.021397801116108894,0.24992097914218903,0.03157894313335419);
+  // Activation function
+  res = tanh(res);
+  //res = max(vec4(0.), res);  // ReLU alternative
+  // output (r,g,b, z)
+  return res;
+}
+
+void mainImage( out vec4 fragColor, in vec2 fragCoord ) {
+  vec4 out_1 = conv3x3(fragCoord);
+  fragColor = vec4(out_1.rgb, .0);
+}
```
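The mat4 literals above come straight from training. To check a generated layer off-GPU, a NumPy reference of the same conv3x3 may help; this is a sketch assuming the tap order used by the generator in CNN.py, and `conv3x3_reference` is a hypothetical helper (recall that GLSL mat4 constructors are column-major):

```python
# Hypothetical NumPy reference for one conv3x3 layer, for validating a
# generated shader against the training pipeline. `coeffs[k]` holds the 16
# mat4 values printed for tap k, `bias` the 4 bias values, and `patch` is
# the (3, 3, 4) input neighborhood.
import numpy as np

def conv3x3_reference(patch, coeffs, bias):
    """patch[dx+1, dy+1] is the (r,g,b,w) sample at offset (dx, dy)."""
    res = np.asarray(bias, dtype=np.float32)
    taps = [(dx, dy) for dx in (-1, 0, 1) for dy in (-1, 0, 1)]
    for k, (dx, dy) in enumerate(taps):
        # GLSL mat4 is column-major, hence the transpose after reshape.
        M = np.asarray(coeffs[k], dtype=np.float32).reshape(4, 4).T
        res = res + M @ patch[dx + 1, dy + 1]
    return np.tanh(res)   # same activation as the shader
```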
