Skip to content

Commit 123bfc0

Browse files
Merge pull request #2 from edowson/master
Fixes for Python3 and Keras-2.3.1
2 parents 0460420 + 861f8b7 commit 123bfc0

File tree

3 files changed

+10
-15
lines changed

3 files changed

+10
-15
lines changed

complexnn/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
#
77
# What this module includes by default:
88
from . import bn, conv, dense, init, norm, pool
9-
# from . import fft
9+
from . import fft
1010

1111
from .bn import ComplexBatchNormalization as ComplexBN
1212
from .conv import (
@@ -17,7 +17,7 @@
1717
WeightNorm_Conv,
1818
)
1919
from .dense import ComplexDense
20-
# from .fft import (fft, ifft, fft2, ifft2, FFT, IFFT, FFT2, IFFT2)
20+
from .fft import (fft, ifft, fft2, ifft2, FFT, IFFT, FFT2, IFFT2)
2121
from .init import (
2222
ComplexIndependentFilters,
2323
IndependentFilters,

scripts/run.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,6 @@
1010
import os, pdb, sys
1111
import time
1212

13-
__version__ = "0.0.0"
14-
15-
16-
1713
#
1814
# Message Formatter
1915
#
@@ -26,7 +22,7 @@ class MsgFormatter(L.Formatter):
2622

2723
def formatTime(self, record, datefmt):
2824
t = record.created
29-
timeFrac = abs(t-long(t))
25+
timeFrac = abs(t-int(t))
3026
timeStruct = time.localtime(record.created)
3127
timeString = ""
3228
timeString += time.strftime("%F %T", timeStruct)
@@ -84,7 +80,7 @@ def addArgs(cls, argp):
8480
argp.add_argument("-l", "--loglevel", default="info", type=str,
8581
choices=cls.LOGLEVELS.keys(),
8682
help="Logging severity level.")
87-
argp.add_argument("-s", "--seed", default=0xe4223644e98b8e64, type=long,
83+
argp.add_argument("-s", "--seed", default=0xe4223644e98b8e64, type=int,
8884
help="Seed for PRNGs.")
8985
argp.add_argument("--summary", action="store_true",
9086
help="""Print a summary of the network.""")
@@ -196,8 +192,7 @@ def getArgParser(prog):
196192
argp = Ap.ArgumentParser(prog = prog,
197193
usage = None,
198194
description = None,
199-
epilog = None,
200-
version = __version__)
195+
epilog = None)
201196
subp = argp.add_subparsers()
202197
argp.set_defaults(argp=argp)
203198
argp.set_defaults(subp=subp)
@@ -207,7 +202,7 @@ def getArgParser(prog):
207202

208203

209204
# Add subcommands
210-
for v in globals().itervalues():
205+
for v in globals().values():
211206
if(isinstance(v, type) and
212207
issubclass(v, Subcommand) and
213208
v != Subcommand):

scripts/training.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ def getResnetModel(d):
150150
activation = d.act
151151
advanced_act = d.aact
152152
drop_prob = d.dropout
153-
inputShape = (3, 32, 32) if K.image_dim_ordering() == "th" else (32, 32, 3)
153+
inputShape = (3, 32, 32) if K.image_data_format() == "channels_first" else (32, 32, 3)
154154
channelAxis = 1 if K.image_data_format() == 'channels_first' else -1
155155
filsize = (3, 3)
156156
convArgs = {
@@ -196,7 +196,7 @@ def getResnetModel(d):
196196
# Stage 2
197197
#
198198

199-
for i in xrange(n):
199+
for i in range(n):
200200
O = getResidualBlock(O, filsize, [sf, sf], 2, str(i), 'regular', convArgs, bnArgs, d)
201201
if i == n//2 and d.spectral_pool_scheme == "stagemiddle":
202202
O = applySpectralPooling(O, d)
@@ -209,7 +209,7 @@ def getResnetModel(d):
209209
if d.spectral_pool_scheme == "nodownsample":
210210
O = applySpectralPooling(O, d)
211211

212-
for i in xrange(n-1):
212+
for i in range(n-1):
213213
O = getResidualBlock(O, filsize, [sf*2, sf*2], 3, str(i+1), 'regular', convArgs, bnArgs, d)
214214
if i == n//2 and d.spectral_pool_scheme == "stagemiddle":
215215
O = applySpectralPooling(O, d)
@@ -222,7 +222,7 @@ def getResnetModel(d):
222222
if d.spectral_pool_scheme == "nodownsample":
223223
O = applySpectralPooling(O, d)
224224

225-
for i in xrange(n-1):
225+
for i in range(n-1):
226226
O = getResidualBlock(O, filsize, [sf*4, sf*4], 4, str(i+1), 'regular', convArgs, bnArgs, d)
227227
if i == n//2 and d.spectral_pool_scheme == "stagemiddle":
228228
O = applySpectralPooling(O, d)

0 commit comments

Comments
 (0)